From 86833eb747d5d5b8be216f7a9ecd365eaecb4cfa Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 11 Feb 2026 11:47:54 +0530 Subject: [PATCH] repalce qwen3next --- src/models/qwen3next.cpp | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 99b1a76a48..886eb3d66f 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -781,15 +781,31 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) + // Choose between build_delta_net_chunking and fused ggml_gated_delta_net based on n_tokens + ggml_tensor * output; + ggml_tensor * new_state; if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + // Fused op expects state as [S_v*S_v*H, n_seqs] + ggml_tensor * state_2d = ggml_reshape_2d(ctx0, state, head_v_dim * head_v_dim * num_v_heads, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state_2d, + hparams.f_norm_rms_eps); + + // Unpack: attn scores then new state + const int64_t attn_elems = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; + const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs; + + output = ggml_view_4d(ctx0, result, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + head_v_dim * sizeof(float), + head_v_dim * num_v_heads * sizeof(float), + head_v_dim * num_v_heads * n_seq_tokens * sizeof(float), + 0); + new_state = ggml_view_1d(ctx0, result, state_elems, attn_elems * sizeof(float)); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + std::pair attn_out; + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + output = attn_out.first; + new_state = attn_out.second; } - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); cb(new_state, "new_state", il);