From b9360c7fe194e8190e1ee8b9da258699d7666e17 Mon Sep 17 00:00:00 2001
From: Yee Man Chan
Date: Sun, 11 Jan 2026 15:58:46 +0800
Subject: [PATCH] MLA KV cache support

---
 convert_hf_to_gguf.py      |  14 +++--
 src/llama-arch.cpp         |   2 +
 src/llama-model.cpp        |   8 ++-
 src/models/kimi-linear.cpp | 112 ++++++++++++++++++++++++++-----------
 4 files changed, 95 insertions(+), 41 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 321930d7e6..3f402a9acb 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5118,6 +5118,9 @@ class KimiLinearModel(TextModel):
             raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

     def set_gguf_parameters(self):
+        # note: To enable MLA KV cache, attention needs to be converted into MQA (ie: GQA with 1 group)
+        self.hparams["num_key_value_heads"] = 1
+
         super().set_gguf_parameters()
         self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
         self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
@@ -5141,7 +5144,7 @@ class KimiLinearModel(TextModel):
         _full_attn_layers = linear_attn_config["full_attn_layers"]
         for il in range(self.hparams["num_hidden_layers"]):
             if il+1 in _full_attn_layers:
-                _num_kv_heads.append(linear_attn_config["num_heads"])
+                _num_kv_heads.append(self.hparams["num_key_value_heads"])
             else:
                 _num_kv_heads.append(0)
         assert(len(_num_kv_heads) == self.hparams["num_hidden_layers"])
@@ -5156,8 +5159,6 @@ class KimiLinearModel(TextModel):
         if kda_head_dim is not None:
             self.gguf_writer.add_kda_head_dim(kda_head_dim)

-        # MLA params - use add_* methods that handle arch substitution
-
         # MLA params - use add_* methods that handle arch substitution
         # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
         q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q"))
@@ -5172,9 +5173,11 @@ class KimiLinearModel(TextModel):
         # Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
         qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
         qk_rope_head_dim = self.hparams.get("qk_rope_head_dim")
-        self.gguf_writer.add_key_length(qk_nope_head_dim + qk_rope_head_dim)
         v_head_dim = self.hparams.get("v_head_dim")
-        self.gguf_writer.add_value_length(v_head_dim)
+        # To enable MLA KV cache, MLA needs to be converted into MQA with larger heads, which is then decompressed back to MHA at attention time
+        self.gguf_writer.add_key_length(self.hparams["kv_lora_rank"] + self.hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["kv_lora_rank"])
+

         # Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
         if "n_embd_head_k_mla" in self.hparams:
@@ -5315,6 +5318,7 @@ class KimiLinearModel(TextModel):
             n_head_kv = self.hparams["num_key_value_heads"]
             v_head_dim = self.hparams["v_head_dim"]
             qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+            logger.info(f"Split kv_b with n_head_kv = {n_head_kv}")

             assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 076509ed8e..6baf3bd4da 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -2312,6 +2312,8 @@ static std::set llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_ATTN_Q_A_NORM,
                 LLM_TENSOR_ATTN_KV_A_MQA,
                 LLM_TENSOR_ATTN_KV_B,
+                LLM_TENSOR_ATTN_K_B,
+                LLM_TENSOR_ATTN_V_B,
                 LLM_TENSOR_ATTN_KV_A_NORM,
             };
         default:
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 59e8d49f08..712c341fd5 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -6771,8 +6771,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // Note: hparams.n_rot may be 72 (from conversion) but actual is 64
                     const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
                     layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
-                    layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, 0);
-
+                    // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
+                    layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED);
+                    if (!layer.wkv_b) { // MLA KV cache enabled
+                        layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0);
+                        layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);
+                    }
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);
                 }

diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index c55116bc69..9d83ca8fa5 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -321,9 +321,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
             // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
             ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
-            cb(Qcur, "mla_Q", il);
-
+
             // Step 2: KV compression
             // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
             ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
@@ -341,37 +339,83 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             // Normalize kv_c
             kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
-
-            // KV decompression: kv = kv_b_proj(kv_c_normed)
-            ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
-            const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
-
-            // Split kv into k_nope and v
-            ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                                                ggml_row_size(kv->type, kv_per_head),
-                                                ggml_row_size(kv->type, kv_per_head * n_head), 0);
-            ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
-                                              ggml_row_size(kv->type, kv_per_head),
-                                              ggml_row_size(kv->type, kv_per_head * n_head),
-                                              ggml_row_size(kv->type, n_embd_head_qk_nope));
-            k_nope = ggml_cont(ctx0, k_nope);
-            Vcur = ggml_cont(ctx0, Vcur);
-            cb(Vcur, "mla_V", il);
-
-            // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
-            // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
-            // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
-            // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
-            ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
-            ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
-            ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
-            cb(Kcur, "mla_K", il);
-
-            // Direct softmax attention (with KV cache)
-            // Use build_attn with inp_attn for proper mask handling
-            cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
-            cb(cur, "mla_out", il);
-
+
+            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
+                // extract q_nope: {n_embd_head_qk_nope, n_head, n_tokens}
+                ggml_tensor * q_nope =
+                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
+                cb(q_nope, "q_nope", il);
+
+                // and q_pe: {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(
+                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd_head_qk_nope, n_tokens, n_head}
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
+                cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+                // {kv_lora_rank, n_head, n_tokens}
+                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                // note: rope must go first for in-place context shifting in build_rope_shift()
+                Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                cb(Qcur, "Qcur", il);
+
+                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                cb(Kcur, "Kcur", il);
+
+                // {kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Vcur = kv_cmpr;
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
+                cb(Qcur, "mla_Q", il);
+                // KV decompression: kv = kv_b_proj(kv_c_normed)
+                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
+                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
+
+                // Split kv into k_nope and v
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                                                    ggml_row_size(kv->type, kv_per_head),
+                                                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
+                                                  ggml_row_size(kv->type, kv_per_head),
+                                                  ggml_row_size(kv->type, kv_per_head * n_head),
+                                                  ggml_row_size(kv->type, n_embd_head_qk_nope));
+                k_nope = ggml_cont(ctx0, k_nope);
+                Vcur = ggml_cont(ctx0, Vcur);
+                cb(Vcur, "mla_V", il);
+
+                // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
+                // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
+                // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
+                // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
+                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
+                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
+                cb(Kcur, "mla_K", il);
+
+                // Direct softmax attention (with MHA KV cache)
+                // Use build_attn with inp_attn for proper mask handling
+                cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            }
         } else {
             // Unknown layer type - this should not happen
             GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
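
The MLA branch above relies on the usual weight-absorption identity: attending over the cached latent [k_pe; kv_c], with wk_b folded into the query and wv_b applied after the softmax, gives the same result as first decompressing kv_c to per-head K/V and running plain MHA. Below is a minimal NumPy sketch of that identity (not part of the patch); all names and shapes are made up for illustration, masking is omitted because it is identical on both paths, and no RoPE is applied, matching the Kimi MLA path.

import numpy as np

rng = np.random.default_rng(0)
n_head, d_nope, d_rope, d_v, r, n_tok = 4, 16, 8, 16, 32, 5   # r ~ kv_lora_rank

q_nope = rng.standard_normal((n_head, n_tok, d_nope))
q_pe   = rng.standard_normal((n_head, n_tok, d_rope))
kv_c   = rng.standard_normal((n_tok, r))           # compressed latent (cached per token)
k_pe   = rng.standard_normal((n_tok, d_rope))      # shared rope-style part (cached per token)
wk_b   = rng.standard_normal((n_head, d_nope, r))  # per-head K decompression
wv_b   = rng.standard_normal((n_head, d_v, r))     # per-head V decompression
scale  = 1.0 / np.sqrt(d_nope + d_rope)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Path 1 (legacy wkv_b route): decompress the latent to full per-head K/V, then attend.
k_nope  = np.einsum('hdr,tr->htd', wk_b, kv_c)
k_full  = np.concatenate([k_nope, np.broadcast_to(k_pe, (n_head, n_tok, d_rope))], axis=-1)
v_full  = np.einsum('hdr,tr->htd', wv_b, kv_c)
q_full  = np.concatenate([q_nope, q_pe], axis=-1)
attn    = softmax(np.einsum('htd,hsd->hts', q_full, k_full) * scale)
out_mha = np.einsum('hts,hsd->htd', attn, v_full)

# Path 2 (MLA KV cache route): absorb wk_b into the query, attend over [k_pe; kv_c]
# as a single KV head, then decompress the latent-space output with wv_b
# (wv_b is what the patch passes to build_attn in place of the former nullptr).
q_abs   = np.einsum('hdr,htd->htr', wk_b, q_nope)
scores  = (np.einsum('htr,sr->hts', q_abs, kv_c)
           + np.einsum('htd,sd->hts', q_pe, k_pe)) * scale
out_lat = np.einsum('hts,sr->htr', softmax(scores), kv_c)
out_mqa = np.einsum('hdr,htr->htd', wv_b, out_lat)

assert np.allclose(out_mha, out_mqa)
print("absorbed-MQA attention matches decompressed-MHA attention")

With this layout each cached K row holds kv_lora_rank + qk_rope_head_dim values and each V row kv_lora_rank values per token per layer, which is what the add_key_length()/add_value_length() change in convert_hf_to_gguf.py encodes, instead of n_head * (qk_nope_head_dim + qk_rope_head_dim) and n_head * v_head_dim for the decompressed MHA cache kept by the legacy path.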