diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ead180523c..cc5e3691c3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4367,7 +4367,37 @@ class Qwen3NextModel(Qwen2MoeModel): elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"): data_torch = data_torch + 1 - yield from super().modify_tensors(data_torch, name, bid) + if "in_proj_qkvz.weight" in name: + # original order: [q, k, v, z] * head_count + # corrected order: [q * head_count, k * head_count, v * head_count, z * head_count] + head_k_dim = self.hparams["linear_key_head_dim"] + head_v_dim = self.hparams["linear_value_head_dim"] + num_v_heads = self.hparams["linear_num_value_heads"] + num_k_heads = self.hparams["linear_num_key_heads"] + hidden_size = self.hparams["hidden_size"] + split_arg_list_qkvz = [ + head_k_dim, # q partition + head_k_dim, # k partition + (num_v_heads // num_k_heads * head_v_dim), # v partition + (num_v_heads // num_k_heads * head_v_dim), # z partition + ] + # view as (n_embd, head_count, [q+k+v+z]) + data_torch = data_torch.permute(1, 0).contiguous() + data_torch = data_torch.view(-1, num_k_heads, sum(split_arg_list_qkvz)) + # split into q, k, v, z + q, k, v, z = torch.split(data_torch, split_arg_list_qkvz, dim=-1) + # flatten dim + head_count + q = q.contiguous().view(hidden_size, -1) + k = k.contiguous().view(hidden_size, -1) + v = v.contiguous().view(hidden_size, -1) + z = z.contiguous().view(hidden_size, -1) + # stack back + qkv = torch.cat([q, k, v], dim=-1).permute(1, 0).contiguous() + z = z.permute(1, 0).contiguous() + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, ".weight"), qkv) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid, ".weight"), z) + else: + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("RND1") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b240e8e4a6..404af1ef03 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1738,6 +1738,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_GATE_INP_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 2ead965469..f736ee6705 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -950,6 +950,8 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_V, LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5de6493b9e..f6cea8f8db 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6763,7 +6763,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); + // note: ssm_in is used by legacy GGUF + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); diff --git a/src/models/models.h b/src/models/models.h index 72b2b760c6..6c40f48042 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -466,7 +466,8 @@ private: ggml_tensor * cur, int il); - ggml_tensor * build_delta_net_chunking( + // returns pair of output and new state + std::pair build_delta_net_chunking( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -478,7 +479,8 @@ private: ggml_tensor * diag_mask, int il); - ggml_tensor * build_delta_net_autoregressive( + // returns pair of output and new state + std::pair build_delta_net_autoregressive( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -493,6 +495,11 @@ private: ggml_tensor * gate, int layer); + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + const llama_model & model; }; diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 775b3135d3..0e4fe7ebdc 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -86,7 +86,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +std::pair llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -187,18 +195,16 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - cb(g_cumsum, "g_cumsum", il); - - ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - - cb(decay_mask, "decay_mask", il); + cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); decay_mask = ggml_exp(ctx0, decay_mask); @@ -208,8 +214,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - - cb(attn, "attn_pre_solve", il); + cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); @@ -217,8 +222,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); attn = ggml_mul(ctx0, lin_solve, causal_mask); attn = ggml_add(ctx0, attn, identity); - - cb(attn, "attn_solved", il); + cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); @@ -226,116 +230,126 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - - cb(kbeta_gexp, "kbeta_gexp", il); + cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - cb(k_cumdecay, "k_cumdecay", il); + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along chunk_size dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp); + cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + + // state to be updated per chunk + ggml_tensor * new_state = state; // ggml_dup(ctx0, state); + cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + + // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * core_attn_out = nullptr; - ggml_tensor * new_state = ggml_dup(ctx0, state); - - cb(new_state, "new_state", il); for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - auto chunkify = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; + // shape: (S_k, chunk_size, 1, H_k * n_seqs) + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - auto chunkify_g = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; + // shape: (S_v, chunk_size, 1, H_v * n_seqs) + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - ggml_tensor * k_chunk = chunkify(k); - ggml_tensor * q_chunk = chunkify(q); - ggml_tensor * v_chunk = chunkify(v); + // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum); - ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk)); - - ggml_tensor * decay_mask_chunk = chunkify(decay_mask); - ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); - - ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t); + // shape: (chunk_size, 1, H_v * n_seqs) + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); - attn = ggml_mul(ctx0, attn, decay_mask_chunk); - attn = ggml_mul(ctx0, attn, diag_mask); + // replaced by precomputed attn_kq + ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) // v_new = v_i - v_prime ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk)); + //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff))); + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - ggml_tensor * g_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3], - g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3], - g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1))); - - ggml_tensor * gexp_last = - ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); - - ggml_tensor * g_cum_last_3d = - ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); - - ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]); - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk, - ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], - g_diff_exp->ne[2] * g_diff_exp->ne[3])); - - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); - + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)), + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); } - core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); - - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0); + // truncate padded tokens + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); cb(output_tokens, "output_tokens", il); - // flatten output - ggml_tensor * flat_output = - ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + // permute back to (S_v, H_v, n_tokens, n_seqs) + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); - ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs); - - return ggml_concat(ctx0, flat_output, flat_state, 0); + return {output_tokens, new_state}; } -ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( +std::pair llm_build_qwen3next::build_delta_net_autoregressive( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -419,11 +433,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( cb(core_attn_out, "output_tokens", il); cb(state, "new_state", il); - // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise - ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs); - ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs); - - return ggml_concat(ctx0, flat_output, flat_state, 0); + return {core_attn_out, state}; } ggml_tensor * llm_build_qwen3next::build_norm_gated( @@ -523,6 +533,87 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( return cur; } +std::pair llm_build_qwen3next::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + if (model.layers[il].wqkv) { + // optimized path + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; + + } else { + // legacy (slower) path + ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Split mixed_qkvz into query, key, value, z + int64_t split_sizes_qkvz[4] = { + head_k_dim, // query size + head_k_dim, // key size + head_v_dim * num_v_heads / num_k_heads, // value size + head_v_dim * num_v_heads / num_k_heads // z size + }; + + ggml_tensor * query = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped)); + cb(key, "k", il); + + ggml_tensor * value = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped)); + cb(z, "z", il); + + // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions + // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(query_flat, "query_flat", il); + + // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(key_flat, "key_flat", il); + + // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(value_flat, "value_flat", il); + + // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + return { qkv_mixed, z }; + } +} + ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, @@ -547,15 +638,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); // Input projections - ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); - cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); cb(mixed_ba, "linear_attn_mixed_ba", il); - int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); - ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); - // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); @@ -575,8 +664,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); cb(a, "a", il); - // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] - ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + + // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -585,48 +675,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); - // Split mixed_qkvz into query, key, value, z - int64_t split_sizes_qkvz[4] = { - head_k_dim, // query size - head_k_dim, // key size - head_v_dim * num_v_heads / num_k_heads, // value size - head_v_dim * num_v_heads / num_k_heads // z size - }; - - ggml_tensor * query = - ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); - cb(query, "q", il); - - ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], - split_sizes_qkvz[0] * sizeof(float)); - cb(key, "k", il); - - ggml_tensor * value = - ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], - (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); - cb(value, "v", il); - - ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], - (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); - cb(z, "z", il); - - // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions - // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] - ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); - cb(query_flat, "query_flat", il); - - // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] - ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); - cb(key_flat, "key_flat", il); - - // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] - ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); - cb(value_flat, "value_flat", il); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); @@ -637,17 +685,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); - // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] - ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); - qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); - cb(qkv_mixed, "qkv_mixed", il); - - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); - - // Calculate the total conv dimension - int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; @@ -655,6 +692,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -677,26 +717,25 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); - conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); - cb(conv_output_proper, "conv_output_pre_silu", il); - ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); cb(conv_output_silu, "conv_output_silu", il); - ggml_tensor * conv_qkv_mix = - ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); - cb(conv_qkv_mix, "conv_qkv_mix", il); + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0); + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); cb(q_conv, "q_conv", il); ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); @@ -705,8 +744,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); - beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); cb(state, "state_predelta", il); @@ -738,45 +775,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(v_conv, "v_conv_predelta", il); // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - ggml_tensor * attn_out; + std::pair attn_out; // pair of (output, new_state) if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); } - cb(attn_out, "attn_out", il); - - // The tensors were concatenated 1d, so we need to extract them 1d as well - const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; - ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); - cb(attn_out_1d, "attn_out_1d", il); - - ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); - cb(attn_out_final, "attn_out_reshaped", il); - - // Extract the state part (second part of the concatenated tensor) - // State starts after n_tokens elements along dimension 1 - const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; - - ggml_tensor * state_1d = - ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); - cb(state_1d, "state_1d", il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, state_1d, + ggml_cpy(ctx0, new_state, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); - // Reshape both attn_out_final and z to 2D tensors for normalization // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = - ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); @@ -828,12 +849,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // The gate needs to be broadcast to match the dimensions of ffn_shexp - // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] - // We need to repeat the gate along the feature dimension - shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); - cb(shared_gate, "shared_expert_gate_broadcast", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il);