repalce qwen3next

This commit is contained in:
Aman Gupta 2026-02-11 11:47:54 +05:30
parent c7edcf22ec
commit 86833eb747
1 changed files with 22 additions and 6 deletions

View File

@ -781,15 +781,31 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
// Choose between build_delta_net_chunking and fused ggml_gated_delta_net based on n_tokens
ggml_tensor * output;
ggml_tensor * new_state;
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
// Fused op expects state as [S_v*S_v*H, n_seqs]
ggml_tensor * state_2d = ggml_reshape_2d(ctx0, state, head_v_dim * head_v_dim * num_v_heads, n_seqs);
ggml_tensor * result = ggml_gated_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state_2d,
hparams.f_norm_rms_eps);
// Unpack: attn scores then new state
const int64_t attn_elems = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;
output = ggml_view_4d(ctx0, result, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
head_v_dim * sizeof(float),
head_v_dim * num_v_heads * sizeof(float),
head_v_dim * num_v_heads * n_seq_tokens * sizeof(float),
0);
new_state = ggml_view_1d(ctx0, result, state_elems, attn_elems * sizeof(float));
} else {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
output = attn_out.first;
new_state = attn_out.second;
}
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);