From a5251ca11d2317d93a7b6da4217483f4e83beb3d Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Tue, 16 Dec 2025 11:59:53 +0100 Subject: [PATCH] Optimization: Qwen3 next autoregressive pass (#17996) * It's Qwen3 Next, the lean mean token generation machine! * Apply patches from thread * Remove recurrent version, only keep chunked and autoregressive * Remove unnecessary conts and asserts * Remove more extra conts and asserts * Cleanup masking --- src/models/models.h | 22 +-- src/models/qwen3next.cpp | 333 +++++++++------------------------------ 2 files changed, 85 insertions(+), 270 deletions(-) diff --git a/src/models/models.h b/src/models/models.h index 6494f54501..ffb36acc61 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -441,23 +441,13 @@ private: ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - ggml_tensor * build_delta_net_recurrent( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - int il); - ggml_tensor * build_delta_net_chunking( ggml_tensor * q, ggml_tensor * k, @@ -467,8 +457,18 @@ private: ggml_tensor * state, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il); + ggml_tensor * build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index c8f1b5ec90..775b3135d3 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -17,13 +17,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_tensor * inp_out_ids = build_inp_out_ids(); ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f), + ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), GGML_TRI_TYPE_LOWER); - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f)); + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); + ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); ggml_build_forward_expand(gf, causal_mask); ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -34,7 +36,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); @@ -93,14 +95,8 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * state, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il) { - GGML_ASSERT(ggml_is_contiguous(q)); - GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(g)); - GGML_ASSERT(ggml_is_contiguous(beta)); - GGML_ASSERT(ggml_is_contiguous(state)); - const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; const int64_t n_tokens = q->ne[2]; @@ -120,15 +116,10 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - // TODO: can this ever be false? - const bool use_qk_l2norm = true; + const float eps_norm = hparams.f_norm_rms_eps; - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); const float scale = 1.0f / sqrtf(S_v); @@ -136,8 +127,6 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); - cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); @@ -188,36 +177,21 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( cb(v_beta, "v_beta", il); cb(k_beta, "k_beta", il); - ggml_tensor * chunked_mask = - ggml_view_4d(ctx0, causal_mask, chunk_size, - chunk_size, causal_mask->ne[2], causal_mask->ne[3], - causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0); + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - ggml_tensor * chunked_diag_mask = - ggml_view_4d(ctx0, causal_diag_mask, chunk_size, - chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3], - causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0); - - ggml_tensor * chunked_identity = - ggml_view_4d(ctx0, identity, chunk_size, - chunk_size, identity->ne[2], identity->ne[3], - identity->nb[1], identity->nb[2], identity->nb[3], 0); - - q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); cb(g_cumsum, "g_cumsum", il); - ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); @@ -226,23 +200,23 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( cb(decay_mask, "decay_mask", il); - decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask)); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); cb(attn, "attn_pre_solve", il); - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower); + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, chunked_mask); - attn = ggml_add(ctx0, attn, chunked_identity); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); cb(attn, "attn_solved", il); @@ -291,7 +265,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); attn = ggml_mul(ctx0, attn, decay_mask_chunk); - attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask)); + attn = ggml_mul(ctx0, attn, diag_mask); ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); @@ -361,23 +335,14 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( return ggml_concat(ctx0, flat_output, flat_state, 0); } -ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( +ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, int il) { - GGML_ASSERT(ggml_is_contiguous(q)); - GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(g)); - GGML_ASSERT(ggml_is_contiguous(beta)); - GGML_ASSERT(ggml_is_contiguous(state)); - const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; const int64_t n_tokens = q->ne[2]; @@ -386,6 +351,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( const int64_t S_v = v->ne[0]; const int64_t H_v = v->ne[1]; + GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing GGML_ASSERT(v->ne[2] == n_tokens); GGML_ASSERT(k->ne[2] == n_tokens); GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); @@ -397,215 +363,65 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - // TODO: can this ever be false? - const bool use_qk_l2norm = true; + const float eps_norm = hparams.f_norm_rms_eps; - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); const float scale = 1.0f / sqrtf(S_v); - q = ggml_scale(ctx0, q, scale); - + q = ggml_scale(ctx0, q, scale); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); - cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); cb(beta, "beta_in", il); cb(g, "g_in", il); - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + // Apply the gated delta rule for the single timestep + // last_recurrent_state = last_recurrent_state * g_t + state = ggml_mul(ctx0, state, g_t); - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // we need to sum over dim=-2, so we transpose, sum, then transpose again + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - cb(k_beta, "k_beta", il); - cb(v_beta, "v_beta", il); - cb(g_cumsum, "g_cumsum", il); + // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + // delta = (v_t - kv_mem) * beta_t + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs] - ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs] + // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); - // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs] - // ggml_tensor * gcs_i_broadcast = - // ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, - // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] - // Don't need this, this one will get auto-broadcast - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - - // Apply lower triangular mask to ensure attention is causal (only past tokens influence current) - decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); - // Apply exponential to get the decay mask values - decay_mask = ggml_exp(ctx0, decay_mask); - // Apply lower triangular mask again to ensure only lower triangular values remain - decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); - - cb(decay_mask, "decay_mask", il); - - // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - cb(kmulkbeta, "kmulkbeta", il); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - - cb(attn, "attn_pre_rec", il); - - // for i in range(1, chunk_size): - // row = attn[..., i, :i].clone() - // sub = attn[..., :i, :i].clone() - // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - // - // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - - // value = attn @ v_beta - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - cb(v, "value_beta", il); - - // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - cb(gexp, "g_cum_exp", il); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - - cb(kbeta_gexp, "kbeta_gexp", il); - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - - cb(k_cumdecay, "k_cumdecay", il); - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - attn = ggml_mul_mat(ctx0, k, q); - attn = ggml_mul(ctx0, attn, decay_mask); - attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask)); - - cb(attn, "attn_decay_key", il); - - ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay); - - cb(v_prime, "v_prime", il); - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime); - - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - - cb(v_new, "v_new", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - - cb(attn_inter, "attn_inter", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); - - cb(v_attn, "v_attn", il); - - ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn); - - cb(core_attn_out, "core_attn_out", il); - - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - ggml_tensor * g_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], - g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3], - g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1))); - - cb(g_cum_last, "g_cum_last", il); - - ggml_tensor * gexp_last = - ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); - - cb(gexp_last, "gexp_last", il); - - ggml_tensor * g_cum_last_3d = - ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); - - cb(g_cum_last_3d, "g_cum_last_3d", il); - - ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]); - - cb(g_cumsum_3d, "g_cumsum_3d", il); - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); - - cb(g_diff, "g_diff", il); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - - cb(g_diff_exp, "g_diff_exp", il); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, - ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], - g_diff_exp->ne[2] * g_diff_exp->ne[3])); - - cb(key_gdiff, "key_gdiff", il); - - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); - - cb(kgdmulvnew, "kgdmulvnew", il); - - state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew); + // Compute the attention output + // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // again, since it's over dim = -2, transpose, sum, transpose back + ggml_tensor * core_attn_out = + ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this + cb(core_attn_out, "output_tokens", il); cb(state, "new_state", il); - // flatten output - ggml_tensor * flat_output = - ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); - - ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs); + // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise + ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs); + ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs); return ggml_concat(ctx0, flat_output, flat_state, 0); } @@ -712,6 +528,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -737,11 +554,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(mixed_ba, "linear_attn_mixed_ba", il); int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); - ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; - ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); + ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); // Split mixed_ba into b and a (beta and alpha parameters) int64_t split_sizes_ba[2] = { @@ -762,8 +579,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); - GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba)); - ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); @@ -799,9 +614,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); cb(z, "z", il); - GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) == - ggml_nelements(mixed_qkvz)); - // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); @@ -925,10 +737,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens - ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ? - build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) : - build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il); + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + ggml_tensor * attn_out; + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + } cb(attn_out, "attn_out", il); // The tensors were concatenated 1d, so we need to extract them 1d as well