model: try to improve Qwen3 Next (#18683)

* qwen3next: simplify qkvz projection

* use ggml_swiglu_split

* revert swiglu_split, but remove redundant repeat()

* fix missing reshape

* rm 2 redundant transposes

* move mul_mat(k,q) to outside of chunking

* rm redundant cont

* improve g_cs_chunk

* add comments about no cont

* use std::pair instead of ggml_concat

* vectorize key_gdiff calculation

* rm unused tensor

* avoid ggml_concat inside loop

* bring back ggml_concat as it may not work on other backend

* nits
This commit is contained in:
Xuan-Son Nguyen 2026-01-11 12:53:33 +01:00 committed by GitHub
parent 79456a690a
commit 506bb6e010
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 245 additions and 187 deletions

View File

@ -4367,7 +4367,37 @@ class Qwen3NextModel(Qwen2MoeModel):
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
data_torch = data_torch + 1
yield from super().modify_tensors(data_torch, name, bid)
if "in_proj_qkvz.weight" in name:
# original order: [q, k, v, z] * head_count
# corrected order: [q * head_count, k * head_count, v * head_count, z * head_count]
head_k_dim = self.hparams["linear_key_head_dim"]
head_v_dim = self.hparams["linear_value_head_dim"]
num_v_heads = self.hparams["linear_num_value_heads"]
num_k_heads = self.hparams["linear_num_key_heads"]
hidden_size = self.hparams["hidden_size"]
split_arg_list_qkvz = [
head_k_dim, # q partition
head_k_dim, # k partition
(num_v_heads // num_k_heads * head_v_dim), # v partition
(num_v_heads // num_k_heads * head_v_dim), # z partition
]
# view as (n_embd, head_count, [q+k+v+z])
data_torch = data_torch.permute(1, 0).contiguous()
data_torch = data_torch.view(-1, num_k_heads, sum(split_arg_list_qkvz))
# split into q, k, v, z
q, k, v, z = torch.split(data_torch, split_arg_list_qkvz, dim=-1)
# flatten dim + head_count
q = q.contiguous().view(hidden_size, -1)
k = k.contiguous().view(hidden_size, -1)
v = v.contiguous().view(hidden_size, -1)
z = z.contiguous().view(hidden_size, -1)
# stack back
qkv = torch.cat([q, k, v], dim=-1).permute(1, 0).contiguous()
z = z.permute(1, 0).contiguous()
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, ".weight"), qkv)
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid, ".weight"), z)
else:
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("RND1")

View File

@ -1738,6 +1738,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,

View File

@ -950,6 +950,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,

View File

@ -6763,7 +6763,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
// note: ssm_in is used by legacy GGUF
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);

View File

@ -466,7 +466,8 @@ private:
ggml_tensor * cur,
int il);
ggml_tensor * build_delta_net_chunking(
// returns pair of output and new state
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
@ -478,7 +479,8 @@ private:
ggml_tensor * diag_mask,
int il);
ggml_tensor * build_delta_net_autoregressive(
// returns pair of output and new state
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
@ -493,6 +495,11 @@ private:
ggml_tensor * gate,
int layer);
// returns pair of qkv, z
std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
ggml_tensor * input,
int il);
const llama_model & model;
};

View File

@ -86,7 +86,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
// utility to get one slice from the third dimension
// input dim: [x, y, c, b]
// output dim: [x, y, 1, b]
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
@ -187,18 +195,16 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
cb(g_cumsum, "g_cumsum", il);
ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j_broadcast =
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
cb(decay_mask, "decay_mask", il);
cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
@ -208,8 +214,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(attn, "attn_pre_solve", il);
cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
@ -217,8 +222,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
attn = ggml_add(ctx0, attn, identity);
cb(attn, "attn_solved", il);
cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
@ -226,116 +230,126 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il);
cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * k_cumdecay =
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
cb(k_cumdecay, "k_cumdecay", il);
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
// vectorized calculation of key_gdiff
// improved from the chunked version:
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
// g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
// key_gdiff = key * g_diff.unsqueeze(-1)
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
// get last element in g_cumsum along chunk_size dimension (ne0)
// example: [[x, y, z, ..., last], ...] -> [[last], ...]
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
(g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
g_last = ggml_cont(ctx0, g_last);
cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
// state to be updated per chunk
ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
// shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * core_attn_out = nullptr;
ggml_tensor * new_state = ggml_dup(ctx0, state);
cb(new_state, "new_state", il);
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
auto chunkify = [=](ggml_tensor * t) {
return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
};
// shape: (S_k, chunk_size, 1, H_k * n_seqs)
ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
auto chunkify_g = [=](ggml_tensor * t) {
return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
};
// shape: (S_v, chunk_size, 1, H_v * n_seqs)
ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
ggml_tensor * k_chunk = chunkify(k);
ggml_tensor * q_chunk = chunkify(q);
ggml_tensor * v_chunk = chunkify(v);
// shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
// shape: (chunk_size, 1, H_v * n_seqs)
ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
// attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
attn = ggml_mul(ctx0, attn, decay_mask_chunk);
attn = ggml_mul(ctx0, attn, diag_mask);
// replaced by precomputed attn_kq
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
// v_new = v_i - v_prime
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
// core_attn_out[:, :, i] = attn_inter + attn @ v_new
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
cb(v_attn, "v_attn_chunk", il);
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
core_attn_out = core_attn_out == nullptr
? core_attn_out_chunk
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
// g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
// key_gdiff = key * g_diff.unsqueeze(-1)
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * g_cum_last =
ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
ggml_tensor * gexp_last =
ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
ggml_tensor * g_cum_last_3d =
ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
g_diff_exp->ne[2] * g_diff_exp->ne[3]));
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
new_state = ggml_add(ctx0,
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
// truncate padded tokens
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);
// flatten output
ggml_tensor * flat_output =
ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
// permute back to (S_v, H_v, n_tokens, n_seqs)
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
return ggml_concat(ctx0, flat_output, flat_state, 0);
return {output_tokens, new_state};
}
ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
@ -419,11 +433,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
// flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
return ggml_concat(ctx0, flat_output, flat_state, 0);
return {core_attn_out, state};
}
ggml_tensor * llm_build_qwen3next::build_norm_gated(
@ -523,6 +533,87 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
return cur;
}
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
ggml_tensor * input,
int il) {
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t n_seqs = ubatch.n_seqs;
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = d_inner / num_v_heads;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
if (model.layers[il].wqkv) {
// optimized path
ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
cb(z, "z", il);
return { qkv_mixed, z };
} else {
// legacy (slower) path
ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
// Split mixed_qkvz into query, key, value, z
int64_t split_sizes_qkvz[4] = {
head_k_dim, // query size
head_k_dim, // key size
head_v_dim * num_v_heads / num_k_heads, // value size
head_v_dim * num_v_heads / num_k_heads // z size
};
ggml_tensor * query =
ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
cb(query, "q", il);
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
cb(key, "k", il);
ggml_tensor * value =
ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
cb(value, "v", il);
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
cb(z, "z", il);
// After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
// query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
cb(query_flat, "query_flat", il);
// key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
cb(key_flat, "key_flat", il);
// value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
cb(value_flat, "value_flat", il);
// Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
cb(qkv_mixed, "qkv_mixed", il);
return { qkv_mixed, z };
}
}
ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
llm_graph_input_rs * inp,
ggml_tensor * cur,
@ -547,15 +638,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
// Input projections
ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
cb(mixed_ba, "linear_attn_mixed_ba", il);
int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
// Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
@ -575,8 +664,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);
// Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
// Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
@ -585,48 +675,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
cb(gate, "gate", il);
// Split mixed_qkvz into query, key, value, z
int64_t split_sizes_qkvz[4] = {
head_k_dim, // query size
head_k_dim, // key size
head_v_dim * num_v_heads / num_k_heads, // value size
head_v_dim * num_v_heads / num_k_heads // z size
};
ggml_tensor * query =
ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
cb(query, "q", il);
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
split_sizes_qkvz[0] * sizeof(float));
cb(key, "k", il);
ggml_tensor * value =
ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
cb(value, "v", il);
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
cb(z, "z", il);
// After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
// query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
cb(query_flat, "query_flat", il);
// key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
cb(key_flat, "key_flat", il);
// value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
cb(value_flat, "value_flat", il);
// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
@ -637,17 +685,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
// Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
cb(qkv_mixed, "qkv_mixed", il);
qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
cb(qkv_mixed, "qkv_mixed_permuted", il);
// Calculate the total conv dimension
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
// Calculate convolution kernel size
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
@ -655,6 +692,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);
qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
cb(qkv_mixed, "qkv_mixed_permuted", il);
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
@ -677,26 +717,25 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
cb(conv_output_proper, "conv_output_raw", il);
conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
cb(conv_output_proper, "conv_output_pre_silu", il);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
cb(conv_output_silu, "conv_output_silu", il);
ggml_tensor * conv_qkv_mix =
ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
cb(conv_qkv_mix, "conv_qkv_mix", il);
ggml_tensor * conv_qkv_mix = conv_output_silu;
// Calculate the total conv dimension
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
cb(q_conv, "q_conv", il);
ggml_tensor * k_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(k_conv, "k_conv", il);
ggml_tensor * v_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(v_conv, "v_conv", il);
@ -705,8 +744,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
cb(state, "state_predelta", il);
@ -738,45 +775,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(v_conv, "v_conv_predelta", il);
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
ggml_tensor * attn_out;
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
}
cb(attn_out, "attn_out", il);
// The tensors were concatenated 1d, so we need to extract them 1d as well
const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
cb(attn_out_1d, "attn_out_1d", il);
ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
cb(attn_out_final, "attn_out_reshaped", il);
// Extract the state part (second part of the concatenated tensor)
// State starts after n_tokens elements along dimension 1
const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
ggml_tensor * state_1d =
ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
cb(state_1d, "state_1d", il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, state_1d,
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
// Reshape both attn_out_final and z to 2D tensors for normalization
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * attn_out_2d_final =
ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
// Apply gated normalization: self.norm(core_attn_out, z)
ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
@ -828,12 +849,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
// The gate needs to be broadcast to match the dimensions of ffn_shexp
// ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
// We need to repeat the gate along the feature dimension
shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
cb(shared_gate, "shared_expert_gate_broadcast", il);
// Apply the gate to the shared expert output
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "ffn_shexp_gated", il);