MLA KV cache support
parent dce064c0a3
commit b9360c7fe1
@@ -5118,6 +5118,9 @@ class KimiLinearModel(TextModel):
        raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

    def set_gguf_parameters(self):
        # note: To enable MLA KV cache, attention needs to be converted into MQA (ie: GQA with 1 group)
        self.hparams["num_key_value_heads"] = 1

        super().set_gguf_parameters()
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

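The single hparams override above is what makes the MLA KV cache possible on the llama.cpp side: the compressed latent is cached once per layer, so the converted GGUF advertises attention as MQA (GQA with one group) regardless of how many query heads the checkpoint has. A minimal sketch of what that implies, using invented head counts rather than Kimi's real config:

```python
# Hedged sketch: the converter forces a single KV "head" so that llama.cpp
# allocates one latent KV stream per layer instead of one stream per head.
hparams = {
    "num_attention_heads": 64,  # invented example value
    "num_key_value_heads": 64,  # MLA checkpoints typically report n_head here
}

hparams["num_key_value_heads"] = 1  # MLA -> MQA, mirroring set_gguf_parameters()

n_gqa_groups = hparams["num_attention_heads"] // hparams["num_key_value_heads"]
print(f"GQA groups after conversion: {n_gqa_groups}")  # 64 query heads share one KV stream
```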
@@ -5141,7 +5144,7 @@ class KimiLinearModel(TextModel):
        _full_attn_layers = linear_attn_config["full_attn_layers"]
        for il in range(self.hparams["num_hidden_layers"]):
            if il+1 in _full_attn_layers:
                _num_kv_heads.append(linear_attn_config["num_heads"])
                _num_kv_heads.append(self.hparams["num_key_value_heads"])
            else:
                _num_kv_heads.append(0)
        assert(len(_num_kv_heads) == self.hparams["num_hidden_layers"])

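KDA (linear attention) layers keep no KV cache at all, so the per-layer KV-head array ends up as 1 for the full-attention (MLA) layers listed in `full_attn_layers` and 0 everywhere else. The same loop as above, runnable on its own with an invented 8-layer layout:

```python
# Standalone version of the per-layer KV-head bookkeeping; layer indices are invented.
hparams = {
    "num_hidden_layers": 8,
    "num_key_value_heads": 1,  # already forced to 1 for the MLA KV cache
}
linear_attn_config = {"full_attn_layers": [4, 8]}  # 1-based indices of the MLA layers

num_kv_heads = []
for il in range(hparams["num_hidden_layers"]):
    if il + 1 in linear_attn_config["full_attn_layers"]:
        num_kv_heads.append(hparams["num_key_value_heads"])  # MLA layer: one latent KV stream
    else:
        num_kv_heads.append(0)  # KDA layer: no KV cache entry

assert len(num_kv_heads) == hparams["num_hidden_layers"]
print(num_kv_heads)  # [0, 0, 0, 1, 0, 0, 0, 1]
```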
@@ -5156,8 +5159,6 @@ class KimiLinearModel(TextModel):
        if kda_head_dim is not None:
            self.gguf_writer.add_kda_head_dim(kda_head_dim)

        # MLA params - use add_* methods that handle arch substitution

        # MLA params - use add_* methods that handle arch substitution
        # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
        q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q"))

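The double lookup exists because some configs expose the MLA ranks under HuggingFace names (`q_lora_rank`, `kv_lora_rank`) while others use the internal names (`n_lora_q`, `n_lora_kv`). A tiny illustration of the fallback pattern; the numeric values are assumptions, not Kimi's real ranks:

```python
# Illustrative only: accept either naming scheme for the MLA ranks.
def first_present(hparams: dict, *keys, default=None):
    for key in keys:
        if hparams.get(key) is not None:
            return hparams[key]
    return default

hparams = {"n_lora_q": 1536, "kv_lora_rank": 512}  # invented example values
q_lora_rank  = first_present(hparams, "q_lora_rank", "n_lora_q")
kv_lora_rank = first_present(hparams, "kv_lora_rank", "n_lora_kv")
print(q_lora_rank, kv_lora_rank)  # 1536 512
```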
@@ -5172,9 +5173,11 @@ class KimiLinearModel(TextModel):
        # Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
        qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
        qk_rope_head_dim = self.hparams.get("qk_rope_head_dim")
        self.gguf_writer.add_key_length(qk_nope_head_dim + qk_rope_head_dim)
        v_head_dim = self.hparams.get("v_head_dim")
        self.gguf_writer.add_value_length(v_head_dim)
        # To enable MLA KV cache, MLA needs to be converted into MQA with larger heads, then decompresses to MHA
        self.gguf_writer.add_key_length(self.hparams["kv_lora_rank"] + self.hparams["qk_rope_head_dim"])
        self.gguf_writer.add_value_length(self.hparams["kv_lora_rank"])

        # Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
        if "n_embd_head_k_mla" in self.hparams:

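With the MLA cache, the single MQA "head" stores the compressed latent plus the positional slice for K and the latent alone for V, which is why `key_length`/`value_length` switch from the decompressed per-head sizes to `kv_lora_rank + qk_rope_head_dim` and `kv_lora_rank`. A worked-arithmetic sketch using DeepSeek-V3-style dimensions (assumed values, not read from Kimi's config):

```python
# Assumed example dimensions; only the arithmetic matters here.
hparams = {
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
    "kv_lora_rank": 512,
}

# Legacy MHA cache: per-head K/V sizes after decompression
k_len_mha = hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]  # 192
v_len_mha = hparams["v_head_dim"]                                      # 128

# MLA cache: the single MQA "head" stores the latent (+ positional slice for K)
k_len_mla = hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]      # 576
v_len_mla = hparams["kv_lora_rank"]                                    # 512

print(k_len_mha, v_len_mha, k_len_mla, v_len_mla)  # 192 128 576 512
```

For, say, 64 query heads this trades roughly 64 · (192 + 128) cached values per token per MLA layer for a single 576/512 latent pair, which is the point of the change.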
@@ -5315,6 +5318,7 @@ class KimiLinearModel(TextModel):
        n_head_kv = self.hparams["num_key_value_heads"]
        v_head_dim = self.hparams["v_head_dim"]
        qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
        logger.info("Split kv_b n_head_kv %d\n" % n_head_kv)

        assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

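The split of the fused `kv_b_proj` weight gives each head a K branch (stored so it can be applied directly to `q_nope` in latent space) and a V branch. A numpy sketch of that split under small assumed dimensions; the real conversion operates on torch tensors and writes the results as the split attention tensors:

```python
import numpy as np

# Assumed toy dimensions, not Kimi's real config
n_head_kv, qk_nope, v_dim, kv_lora = 4, 16, 16, 32

# kv_b_proj weight as stored in the checkpoint: per-head [k_nope | v] row blocks
kv_b = np.random.randn(n_head_kv * (qk_nope + v_dim), kv_lora).astype(np.float32)
assert kv_b.shape[0] == n_head_kv * (v_dim + qk_nope)  # same check as in the diff

# Split each head's block into its K part and V part
kv_b_heads = kv_b.reshape(n_head_kv, qk_nope + v_dim, kv_lora)
k_b = kv_b_heads[:, :qk_nope, :]   # (n_head_kv, qk_nope, kv_lora)
v_b = kv_b_heads[:, qk_nope:, :]   # (n_head_kv, v_dim,   kv_lora)

# For the absorbed-Q path it is convenient to keep k_b transposed per head,
# i.e. a (kv_lora, qk_nope) matrix that maps q_nope into latent space.
k_b_t = k_b.transpose(0, 2, 1)     # (n_head_kv, kv_lora, qk_nope)

print(k_b_t.shape, v_b.shape)
```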
@@ -2312,6 +2312,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                LLM_TENSOR_ATTN_Q_A_NORM,
                LLM_TENSOR_ATTN_KV_A_MQA,
                LLM_TENSOR_ATTN_KV_B,
                LLM_TENSOR_ATTN_K_B,
                LLM_TENSOR_ATTN_V_B,
                LLM_TENSOR_ATTN_KV_A_NORM,
            };
        default:

@@ -6771,8 +6771,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        // Note: hparams.n_rot may be 72 (from conversion) but actual is 64
        const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
        layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, 0);

        // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
        layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED);
        if (!layer.wkv_b) { // MLA KV cache enabled
            layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0);
            layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);
        }
        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);
    }

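Loading `wkv_b` with `TENSOR_NOT_REQUIRED` keeps legacy GGUFs (fused `kv_b` only) working, while newly converted files ship the split pair and take the MLA-cache branch. To see which case a given file falls into, something along these lines should work with the `gguf` Python package; the `blk.<i>.attn_*` tensor names are an assumption based on the usual llama.cpp naming, not something spelled out in this diff:

```python
import sys
from gguf import GGUFReader  # gguf-py, shipped with llama.cpp

def describe_mla_tensors(path: str) -> str:
    # Assumed naming: "blk.<i>.attn_k_b.weight" etc.; adjust if the file differs.
    names = {t.name for t in GGUFReader(path).tensors}
    has_split = any(n.endswith(".attn_k_b.weight") for n in names)
    has_fused = any(n.endswith(".attn_kv_b.weight") for n in names)
    if has_split:
        return "split attn_k_b/attn_v_b present -> MLA KV cache path"
    if has_fused:
        return "only fused attn_kv_b present -> legacy MHA KV cache fallback"
    return "no MLA attention tensors found"

if __name__ == "__main__":
    print(describe_mla_tensors(sys.argv[1]))
```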
@@ -321,9 +321,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
    // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
    ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
    cb(Qcur, "mla_Q", il);

    // Step 2: KV compression
    // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
    ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);

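The `kv_a_proj_with_mqa` output is a single `(kv_lora_rank + qk_rope_head_dim)`-wide vector per token, which the builder then splits into the compressed KV latent and the shared positional slice `k_pe`. In plain numpy the shapes look roughly like this (toy sizes, a plain matmul standing in for `ggml_mul_mat`):

```python
import numpy as np

# Assumed toy dimensions
n_tokens, n_embd, kv_lora, qk_rope = 3, 64, 32, 8

w_kv_a = np.random.randn(kv_lora + qk_rope, n_embd).astype(np.float32)
hidden = np.random.randn(n_tokens, n_embd).astype(np.float32)

kv_cmpr_pe = hidden @ w_kv_a.T        # (n_tokens, kv_lora + qk_rope)
kv_cmpr    = kv_cmpr_pe[:, :kv_lora]  # compressed KV latent, shared by all heads
k_pe       = kv_cmpr_pe[:, kv_lora:]  # positional slice of K, one per token, broadcast to heads

print(kv_cmpr.shape, k_pe.shape)  # (3, 32) (3, 8)
```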
@@ -341,37 +339,83 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll

    // Normalize kv_c
    kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);

    // KV decompression: kv = kv_b_proj(kv_c_normed)
    ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
    const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;

    // Split kv into k_nope and v
    ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
            ggml_row_size(kv->type, kv_per_head),
            ggml_row_size(kv->type, kv_per_head * n_head), 0);
    ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
            ggml_row_size(kv->type, kv_per_head),
            ggml_row_size(kv->type, kv_per_head * n_head),
            ggml_row_size(kv->type, n_embd_head_qk_nope));
    k_nope = ggml_cont(ctx0, k_nope);
    Vcur = ggml_cont(ctx0, Vcur);
    cb(Vcur, "mla_V", il);

    // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
    // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
    // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
    // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
    ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
    ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
    ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
    cb(Kcur, "mla_K", il);

    // Direct softmax attention (with KV cache)
    // Use build_attn with inp_attn for proper mask handling
    cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
    cb(cur, "mla_out", il);

    if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
        // extract q_nope
        ggml_tensor * q_nope =
            ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
                ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
        cb(q_nope, "q_nope", il);

        // and {n_embd_head_qk_rope, n_head, n_tokens}
        ggml_tensor * q_pe = ggml_view_3d(
            ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
            ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
        cb(q_pe, "q_pe", il);

        // {n_embd_head_qk_nope, n_tokens, n_head}
        q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
        cb(q_nope, "q_nope_perm", il);

        // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
        ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
        cb(q_nope_absorbed, "q_nope_absorbed", il);

        // {kv_lora_rank, n_head, n_tokens}
        q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
        cb(q_nope_absorbed, "q_nope_absorbed_perm", il);

        // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
        // note: rope must go first for in-place context shifting in build_rope_shift()
        Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
        cb(Qcur, "Qcur", il);

        kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
        cb(kv_cmpr, "kv_cmpr_reshape", il);

        // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
        ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
        cb(Kcur, "Kcur", il);

        // {kv_lora_rank, 1, n_tokens}
        ggml_tensor * Vcur = kv_cmpr;
        cb(Vcur, "Vcur", il);

        cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
        cb(cur, "mla_out", il);
    } else { // MLA KV cache disabled. Fall back to MHA KV cache.
        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
        cb(Qcur, "mla_Q", il);
        // KV decompression: kv = kv_b_proj(kv_c_normed)
        ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
        const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;

        // Split kv into k_nope and v
        ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
                ggml_row_size(kv->type, kv_per_head),
                ggml_row_size(kv->type, kv_per_head * n_head), 0);
        ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
                ggml_row_size(kv->type, kv_per_head),
                ggml_row_size(kv->type, kv_per_head * n_head),
                ggml_row_size(kv->type, n_embd_head_qk_nope));
        k_nope = ggml_cont(ctx0, k_nope);
        Vcur = ggml_cont(ctx0, Vcur);
        cb(Vcur, "mla_V", il);

        // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
        // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
        // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
        // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
        ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
        ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
        ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
        cb(Kcur, "mla_K", il);

        // Direct softmax attention (with MHA KV cache)
        // Use build_attn with inp_attn for proper mask handling
        cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
        cb(cur, "mla_out", il);
    }
} else {
    // Unknown layer type - this should not happen
    GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");

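The heart of the change is the `layer.wk_b && layer.wv_b` branch: instead of decompressing K and V per head, `q_nope` is absorbed through `wk_b` into latent space, the cache stores only `[k_pe, kv_cmpr]` per token, and `wv_b` is applied after attention (it is the extra tensor handed to `build_attn` where the fallback branch passes `nullptr`). Both branches compute the same attention output; the numpy sketch below checks that equivalence on toy dimensions (all sizes invented, one query token, no masking or RMS norm):

```python
import numpy as np

rng = np.random.default_rng(0)

# Invented toy dimensions
n_head, qk_nope, qk_rope, v_dim, kv_lora, n_kv = 2, 4, 2, 3, 6, 5

# Per-head decompression matrices (the split wk_b / wv_b from the conversion step)
wk_b = rng.standard_normal((n_head, qk_nope, kv_lora))  # k_nope_h = wk_b[h] @ c
wv_b = rng.standard_normal((n_head, v_dim,  kv_lora))   # v_h      = wv_b[h] @ c

# Cached per-token latents and positional key slices
c    = rng.standard_normal((n_kv, kv_lora))             # compressed KV latent (the MLA cache "V")
k_pe = rng.standard_normal((n_kv, qk_rope))             # positional slice of K, shared by all heads

# One query token
q_nope = rng.standard_normal((n_head, qk_nope))
q_pe   = rng.standard_normal((n_head, qk_rope))
scale  = 1.0 / np.sqrt(qk_nope + qk_rope)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# --- Path 1: decompress to MHA (the fallback branch) ---
k_nope  = np.einsum("hdk,tk->htd", wk_b, c)                        # (n_head, n_kv, qk_nope)
v       = np.einsum("hvk,tk->htv", wv_b, c)                        # (n_head, n_kv, v_dim)
scores  = np.einsum("hd,htd->ht", q_nope, k_nope) + q_pe @ k_pe.T  # (n_head, n_kv)
out_mha = np.einsum("ht,htv->hv", softmax(scores * scale), v)

# --- Path 2: absorbed MQA (the MLA KV cache branch) ---
q_abs   = np.einsum("hdk,hd->hk", wk_b, q_nope)                    # q_nope absorbed into latent space
scores2 = q_abs @ c.T + q_pe @ k_pe.T                              # K cache holds only [k_pe | c]
ctx     = softmax(scores2 * scale) @ c                             # V cache is the latent itself
out_mla = np.einsum("hvk,hk->hv", wv_b, ctx)                       # decompress after attention (wv_b)

assert np.allclose(out_mha, out_mla)
print("decompressed MHA and absorbed MLA paths agree:", out_mha.shape)
```

This is the same weight-absorption idea llama.cpp already uses for DeepSeek-style MLA: the cached K is just the latent plus the positional slice, so its width matches the `kv_lora_rank + qk_rope_head_dim` key length written by the converter above.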