From f1525b36959e24776c57031f00fce0212cc3eff8 Mon Sep 17 00:00:00 2001
From: Yee Man Chan
Date: Tue, 27 Jan 2026 11:25:13 +0800
Subject: [PATCH] Add new class llm_graph_input_mem_hybrid_k to work around
 the new MLA change. Switch the concat order of the ggml_concat calls in
 kimi-linear.cpp to accommodate the MLA changes. Remove support for
 exp_probs_b.weight

---
 src/llama-graph.cpp        | 52 ++++++++++++++++++++++++++++++++++++++
 src/llama-graph.h          | 29 +++++++++++++++++++++
 src/llama-model.cpp        | 30 ++++++++++------------
 src/models/kimi-linear.cpp | 12 ++++-----
 4 files changed, 100 insertions(+), 23 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 06d0d4c558..1aebc012a1 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -533,6 +533,47 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+    return res;
+}
+
 void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
     const auto * attn_ctx = mctx->get_attn();
 
@@ -2272,6 +2313,17 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
 
+llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl    (ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
+}
+
 llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
     const auto * mctx_cur = static_cast(mctx);
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 4090d8116c..1d69ff1a6f 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -433,6 +433,34 @@ public:
     const llama_memory_hybrid_context * mctx;
 };
 
+class llm_graph_input_mem_hybrid_k : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid_k(
+            const llama_cparams & cparams,
+            std::unique_ptr<llm_graph_input_attn_k> inp_attn,
+            std::unique_ptr<llm_graph_input_rs>     inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        cparams(cparams),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid_k() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    bool can_reuse(const llm_graph_params & params) override;
+
+    std::unique_ptr<llm_graph_input_attn_k> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>     inp_rs;
+
+    llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs     * get_recr() const { return inp_rs.get(); }
+
+    const llama_cparams cparams;
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid_iswa(
@@ -960,6 +988,7 @@ struct llm_graph_context {
     //
 
     llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+    llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const;
     llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 67b0314de9..84ac4d3a9e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2454,12 +2454,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_KIMI_LINEAR:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false);
-                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
-                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv, false);
-                ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv, false);
-                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim, false);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl);
+                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+                ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim);
 
                 // MLA qk_rope_head_dim (for reference)
                 // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192
@@ -2471,11 +2471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
 
                 // MoE parameters - Kimi uses moe_intermediate_size = 1024
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
-                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
 
                 switch (hparams.n_layer) {
@@ -6863,8 +6862,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             // MLA Layer - use MLA-specific head dimensions
                            const int64_t q_lora_rank = hparams.n_lora_q;
                            const int64_t kv_lora_rank = hparams.n_lora_kv;
-                            const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla;
-                            const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla;
+                            const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+                            const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
 
                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED);
                            layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
@@ -6917,10 +6916,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED);
                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED);
 
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
-                            if (!layer.ffn_exp_probs_b) {
-                                layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, TENSOR_NOT_REQUIRED);
-                            }
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
                        }
                    }
                } break;
diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index 721bef9e7f..3ea404dd0b 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -72,7 +72,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
     // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
     // So we don't need inp_pos
-    auto * inp = build_inp_mem_hybrid();
+    auto * inp = build_inp_mem_hybrid_k();
     auto * inp_rs = inp->get_recr();
     auto * inp_attn = inp->get_attn();
 
@@ -104,8 +104,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     // MLA params
-    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla;
-    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla;
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
     const int64_t kv_lora_rank = hparams.n_lora_kv;
     // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
     // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
@@ -258,14 +258,14 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
                 // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
                 // note: rope must go first for in-place context shifting in build_rope_shift()
-                Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
                 kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                 cb(kv_cmpr, "kv_cmpr_reshape", il);
 
                 // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                 cb(Kcur, "Kcur", il);
 
                 // {kv_lora_rank, 1, n_tokens}
@@ -299,7 +299,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
                 // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
                 ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
                 ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
-                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
                 cb(Kcur, "mla_K", il);
 
                 // Direct softmax attention (with MHA KV cache)
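
Note for reviewers (not part of the applied patch): a minimal C++ sketch of how a hybrid-memory model builder is expected to consume the new helper, mirroring the kimi-linear.cpp hunk above. Only build_inp_mem_hybrid_k(), get_attn() and get_recr() come from this patch; the surrounding graph-build context is assumed.

    // sketch only, inside an llm_build_* graph constructor (assumed context):
    auto * inp      = build_inp_mem_hybrid_k();  // replaces build_inp_mem_hybrid()
    auto * inp_attn = inp->get_attn();           // llm_graph_input_attn_k, consumed by the MLA layers
    auto * inp_rs   = inp->get_recr();           // llm_graph_input_rs, consumed by the KDA/recurrent layers

    // With the switched concat order shown above, the cached K rows are laid out as
    // [kv_lora_rank | n_rot], i.e. compressed KV first and the rope part last,
    // which is the layout the new MLA code in master expects.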