This commit is contained in:
Georgi Gerganov 2026-02-06 07:56:44 -06:00 committed by GitHub
commit 54bb2345b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 54 additions and 59 deletions

View File

@ -117,7 +117,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
@ -141,25 +141,23 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
cb(beta, "beta_in", il);
cb(g, "g_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_v, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state, "state_in", il);
GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_v && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_v && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
// Do padding
const int64_t chunk_size = CHUNK_SIZE;
@ -180,19 +178,19 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
cb(g, "g_pad", il);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, beta, k->ne[0], beta->ne[1], beta->ne[2], beta->ne[3]), k);
cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);
q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs);
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
@ -237,8 +235,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
attn_kq = ggml_mul(ctx0, decay_mask, attn_kq);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
@ -268,17 +266,12 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
1, chunk_size, n_chunks, g_diff_exp->ne[3]);
ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
ggml_tensor * key_gdiff = ggml_mul(ctx0, ggml_repeat_4d(ctx0, g_diff_exp_t, k->ne[0], g_diff_exp_t->ne[1], g_diff_exp_t->ne[2], g_diff_exp_t->ne[3]), k);
cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
// state to be updated per chunk
ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
// shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * core_attn_out = nullptr;
@ -300,7 +293,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
@ -312,7 +305,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
cb(v_new, "v_new_chunk", il);
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
ggml_tensor * q_g_exp = ggml_mul(ctx0, ggml_repeat_4d(ctx0, gexp_chunk, q_chunk->ne[0], gexp_chunk->ne[1], gexp_chunk->ne[2], gexp_chunk->ne[3]), q_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
@ -334,8 +327,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
new_state = ggml_add(ctx0,
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
state = ggml_add(ctx0,
ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
@ -345,14 +338,14 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);
// permute back to (S_v, H_v, n_tokens, n_seqs)
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
return {output_tokens, new_state};
return {output_tokens, state};
}
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
@ -376,33 +369,35 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
//GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
const float scale = 1.0f / sqrtf(S_k);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 1, 2, 0, 3);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_v, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_v, n_seqs);
// Apply exponential to g_t
g_t = ggml_exp(ctx0, g_t);
@ -412,28 +407,26 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
state = ggml_mul(ctx0, state, g_t);
// kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
// we need to sum over dim=-2, so we transpose, sum, then transpose again
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
state = ggml_cont(ctx0, ggml_transpose(ctx0, state));
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k);
kv_mem = ggml_sum_rows(ctx0, kv_mem);
// v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
// delta = (v_t - kv_mem) * beta_t
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
ggml_tensor * v_diff = ggml_sub(ctx0, v, kv_mem); // both should be [1, S_v, H_v, n_seqs]
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
// Compute the attention output
// core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
// again, since it's over dim = -2, transpose, sum, transpose back
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
ggml_tensor * state_q = ggml_mul(ctx0, state, q);
ggml_tensor * core_attn_out = ggml_sum_rows(ctx0, state_q);
core_attn_out = ggml_transpose(ctx0, core_attn_out);
state = ggml_transpose(ctx0, state);
// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
cb(core_attn_out, "output_tokens", il);
@ -734,25 +727,27 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
cb(q_conv, "q_conv", il);
ggml_tensor * k_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_v_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
cb(q_conv, "q_conv", il);
cb(k_conv, "k_conv", il);
ggml_tensor * v_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(v_conv, "v_conv", il);
// Unsqueeze them
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
cb(state, "state_predelta", il);
// if head keys and value keys are different, repeat to force tensors into matching shapes
@ -818,7 +813,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(cur, "linear_attn_out", il);
// Reshape back to original dimensions
cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
return cur;
}