repalce qwen3next
This commit is contained in:
parent
c7edcf22ec
commit
86833eb747
|
|
@ -781,15 +781,31 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||||
cb(k_conv, "k_conv_predelta", il);
|
cb(k_conv, "k_conv_predelta", il);
|
||||||
cb(v_conv, "v_conv_predelta", il);
|
cb(v_conv, "v_conv_predelta", il);
|
||||||
|
|
||||||
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
|
// Choose between build_delta_net_chunking and fused ggml_gated_delta_net based on n_tokens
|
||||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
|
ggml_tensor * output;
|
||||||
|
ggml_tensor * new_state;
|
||||||
if (n_seq_tokens == 1) {
|
if (n_seq_tokens == 1) {
|
||||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
// Fused op expects state as [S_v*S_v*H, n_seqs]
|
||||||
|
ggml_tensor * state_2d = ggml_reshape_2d(ctx0, state, head_v_dim * head_v_dim * num_v_heads, n_seqs);
|
||||||
|
ggml_tensor * result = ggml_gated_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state_2d,
|
||||||
|
hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
// Unpack: attn scores then new state
|
||||||
|
const int64_t attn_elems = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
|
||||||
|
const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;
|
||||||
|
|
||||||
|
output = ggml_view_4d(ctx0, result, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
|
||||||
|
head_v_dim * sizeof(float),
|
||||||
|
head_v_dim * num_v_heads * sizeof(float),
|
||||||
|
head_v_dim * num_v_heads * n_seq_tokens * sizeof(float),
|
||||||
|
0);
|
||||||
|
new_state = ggml_view_1d(ctx0, result, state_elems, attn_elems * sizeof(float));
|
||||||
} else {
|
} else {
|
||||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||||
|
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
||||||
|
output = attn_out.first;
|
||||||
|
new_state = attn_out.second;
|
||||||
}
|
}
|
||||||
ggml_tensor * output = attn_out.first;
|
|
||||||
ggml_tensor * new_state = attn_out.second;
|
|
||||||
cb(output, "attn_output", il);
|
cb(output, "attn_output", il);
|
||||||
cb(new_state, "new_state", il);
|
cb(new_state, "new_state", il);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue