From 8abcc70a747eff568198ae64aa1a60b7625a3c36 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 4 Feb 2026 13:09:58 +0100 Subject: [PATCH] model: (qwen3next) correct vectorized key_gdiff calculation (#19324) * model: (qwen3next) correct vectorized key_gdiff calculation * move transpose to outside of loop --- src/models/qwen3next.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 57b6659baf..06d946c5fa 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -265,9 +265,15 @@ std::pair llm_build_qwen3next::build_delta_net_chu cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, + 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) + // state to be updated per chunk ggml_tensor * new_state = state; // ggml_dup(ctx0, state); @@ -322,9 +328,9 @@ std::pair llm_build_qwen3next::build_delta_net_chu : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk)); + ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff))); + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));