MLA KV cache support

Yee Man Chan 2026-01-11 15:58:46 +08:00
parent dce064c0a3
commit b9360c7fe1
4 changed files with 95 additions and 41 deletions

View File

@ -5118,6 +5118,9 @@ class KimiLinearModel(TextModel):
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
def set_gguf_parameters(self):
# note: To enable MLA KV cache, attention needs to be converted into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1
super().set_gguf_parameters()
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
@ -5141,7 +5144,7 @@ class KimiLinearModel(TextModel):
_full_attn_layers = linear_attn_config["full_attn_layers"]
for il in range(self.hparams["num_hidden_layers"]):
if il+1 in _full_attn_layers:
_num_kv_heads.append(linear_attn_config["num_heads"])
_num_kv_heads.append(self.hparams["num_key_value_heads"])
else:
_num_kv_heads.append(0)
assert(len(_num_kv_heads) == self.hparams["num_hidden_layers"])
@ -5156,8 +5159,6 @@ class KimiLinearModel(TextModel):
if kda_head_dim is not None:
self.gguf_writer.add_kda_head_dim(kda_head_dim)
# MLA params - use add_* methods that handle arch substitution
# MLA params - use add_* methods that handle arch substitution
# Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q"))
@ -5172,9 +5173,11 @@ class KimiLinearModel(TextModel):
# Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
qk_rope_head_dim = self.hparams.get("qk_rope_head_dim")
self.gguf_writer.add_key_length(qk_nope_head_dim + qk_rope_head_dim)
v_head_dim = self.hparams.get("v_head_dim")
self.gguf_writer.add_value_length(v_head_dim)
# To enable the MLA KV cache, MLA is converted into MQA with larger heads; the cached latent is decompressed back to MHA at attention time
self.gguf_writer.add_key_length(self.hparams["kv_lora_rank"] + self.hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(self.hparams["kv_lora_rank"])
# Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
if "n_embd_head_k_mla" in self.hparams:
@ -5315,6 +5318,7 @@ class KimiLinearModel(TextModel):
n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
logger.info("Split kv_b n_head_kv %d\n" % n_head_kv)
assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
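
The assertion above guards the per-head layout used when wkv_b is split into separate K/V decompression tensors. A minimal torch sketch of that split, mirroring the approach the DeepSeek-V2 converter takes (sizes and variable names here are assumptions, not the exact Kimi code):

import torch

n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 1, 128, 128, 512   # assumed sizes
kv_b = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv = kv_b.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2).contiguous()   # stored transposed so it can be absorbed into q_nope
# k_b: (n_head_kv, kv_lora_rank, qk_nope_head_dim) -> attn_k_b
# v_b: (n_head_kv, v_head_dim, kv_lora_rank)       -> attn_v_b
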

View File

@ -2312,6 +2312,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_KV_A_NORM,
};
default:

View File

@ -6771,8 +6771,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// Note: hparams.n_rot may be 72 (from conversion) but actual is 64
const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, 0);
// Support legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED);
if (!layer.wkv_b) { // MLA KV cache enabled
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0);
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);
}
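
To check which path a given conversion will hit, it is enough to look at the tensor names in the GGUF. A small sketch, assuming gguf-py's GGUFReader and the usual blk.<i>.attn_* naming; the file name is hypothetical:

from gguf import GGUFReader

reader = GGUFReader("kimi-linear.gguf")                       # hypothetical file name
names = {t.name for t in reader.tensors}
has_split = any(n.endswith("attn_k_b.weight") for n in names)
has_fused = any(n.endswith("attn_kv_b.weight") for n in names)
print(f"split k_b/v_b tensors: {has_split}, fused kv_b tensor: {has_fused}")
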

View File

@ -321,9 +321,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
// Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
cb(Qcur, "mla_Q", il);
// Step 2: KV compression
// kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
@ -341,37 +339,83 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Normalize kv_c
kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
// KV decompression: kv = kv_b_proj(kv_c_normed)
ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
// Split kv into k_nope and v
ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, kv_per_head),
ggml_row_size(kv->type, kv_per_head * n_head), 0);
ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
ggml_row_size(kv->type, kv_per_head),
ggml_row_size(kv->type, kv_per_head * n_head),
ggml_row_size(kv->type, n_embd_head_qk_nope));
k_nope = ggml_cont(ctx0, k_nope);
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "mla_V", il);
// Concatenate k_nope + k_pe (broadcast k_pe to all heads)
// K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
// and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
// Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
cb(Kcur, "mla_K", il);
// Direct softmax attention (with KV cache)
// Use build_attn with inp_attn for proper mask handling
cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
cb(cur, "mla_out", il);
if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
// extract q_nope: {n_embd_head_qk_nope, n_head, n_tokens}
ggml_tensor * q_nope =
ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
cb(q_nope, "q_nope", il);
// extract q_pe: {n_embd_head_qk_rope, n_head, n_tokens}
ggml_tensor * q_pe = ggml_view_3d(
ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
cb(q_pe, "q_pe", il);
// {n_embd_head_qk_nope, n_tokens, n_head}
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
cb(q_nope, "q_nope_perm", il);
// {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
cb(q_nope_absorbed, "q_nope_absorbed", il);
// {kv_lora_rank, n_head, n_tokens}
q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// note: rope must go first for in-place context shifting in build_rope_shift()
Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
cb(Qcur, "Qcur", il);
kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
cb(kv_cmpr, "kv_cmpr_reshape", il);
// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
cb(Kcur, "Kcur", il);
// {kv_lora_rank, 1, n_tokens}
ggml_tensor * Vcur = kv_cmpr;
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
cb(cur, "mla_out", il);
} else { // MLA KV cache disabled. Fall back to MHA KV cache.
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
cb(Qcur, "mla_Q", il);
// KV decompression: kv = kv_b_proj(kv_c_normed)
ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
// Split kv into k_nope and v
ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, kv_per_head),
ggml_row_size(kv->type, kv_per_head * n_head), 0);
ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
ggml_row_size(kv->type, kv_per_head),
ggml_row_size(kv->type, kv_per_head * n_head),
ggml_row_size(kv->type, n_embd_head_qk_nope));
k_nope = ggml_cont(ctx0, k_nope);
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "mla_V", il);
// Concatenate k_nope + k_pe (broadcast k_pe to all heads)
// K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
// and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
// Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
cb(Kcur, "mla_K", il);
// Direct softmax attention (with MHA KV cache)
// Use build_attn with inp_attn for proper mask handling
cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
cb(cur, "mla_out", il);
}
} else {
// Unknown layer type - this should not happen
GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");