diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6cc28eff28..e8e1bbf1cd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7462,6 +7462,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.wo_s && layer.wo) { layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // dense FFN weight scales (per-tensor, shape {1}) if (!layer.ffn_gate_s && layer.ffn_gate) { @@ -7473,6 +7479,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ffn_up_s && layer.ffn_up) { layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // MoE expert weight scales (per-expert, shape {n_expert}) if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { @@ -7484,6 +7499,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); } + + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } diff --git a/src/llama-model.h b/src/llama-model.h index 9a2dacecca..25bf892e7e 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -401,9 +401,17 @@ struct llama_layer { struct ggml_tensor * wk_s = nullptr; struct ggml_tensor * wv_s = nullptr; struct ggml_tensor * wo_s = nullptr; + struct ggml_tensor * wqkv_s = nullptr; + struct ggml_tensor * wqkv_gate_s = nullptr; struct ggml_tensor * ffn_gate_s = nullptr; struct ggml_tensor * ffn_up_s = nullptr; struct ggml_tensor * ffn_down_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_s = nullptr; + struct ggml_tensor * ffn_up_shexp_s = nullptr; + struct ggml_tensor * ffn_down_shexp_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_alpha_s = nullptr; + struct ggml_tensor * ssm_beta_s = nullptr; // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index e12dad7001..3108bf331a 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -90,11 +90,11 @@ std::pair llm_build_qwen35::build_qkvz( const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); cb(qkv_mixed, "linear_attn_qkv_mixed", il); - ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); cb(z, "z", il); return { qkv_mixed, z }; @@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] cb(Qcur_full, "Qcur_full", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, @@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); // Apply K normalization @@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( cur = ggml_mul(ctx0, cur, gate_sigmoid); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(final_output, "final_output", il); // Output projection - cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); cb(cur, "linear_attn_out", il); // Reshape back to original dimensions @@ -370,9 +370,9 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 8d07c7ed27..165e2412e5 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -90,11 +90,11 @@ std::pair llm_build_qwen35moe::build_qkvz( const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); cb(qkv_mixed, "linear_attn_qkv_mixed", il); - ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); cb(z, "z", il); return { qkv_mixed, z }; @@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] cb(Qcur_full, "Qcur_full", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, @@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); // Apply K normalization @@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( cur = ggml_mul(ctx0, cur, gate_sigmoid); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(final_output, "final_output", il); // Output projection - cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); cb(cur, "linear_attn_out", il); // Reshape back to original dimensions @@ -380,16 +380,19 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int LLM_FFN_SILU, true, hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, - nullptr, model.layers[il].ffn_gate_up_exps); + nullptr, model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il);