use is_mla to switch between different mem_hybrid types

Yee Man Chan 2026-02-01 20:12:20 +08:00
parent 2c8cd844d0
commit 11282a0f60
1 changed file with 7 additions and 5 deletions
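
In short: the Kimi Linear graph builder previously always created the MLA-style hybrid memory input via build_inp_mem_hybrid_k(). After this change it builds either the regular hybrid input (build_inp_mem_hybrid(), backing the MHA KV-cache fallback) or the K-only MLA variant (build_inp_mem_hybrid_k()) depending on hparams.is_mla(), and each build_attn call site receives the matching input. Below is a minimal, self-contained C++ sketch of that selection pattern; the structs and types are simplified stand-ins for illustration, not the real llama.cpp graph-builder classes.

// Sketch of the is_mla() switch introduced by this commit.
// All struct names here are hypothetical stand-ins; only the
// selection logic mirrors the diff below.
#include <cstdio>

struct attn_input {};  // stand-in for the attention-side hybrid input
struct recr_input {};  // stand-in for the recurrent-side hybrid input

// Stand-in for the MHA-style hybrid memory input (build_inp_mem_hybrid).
struct mem_hybrid_kv {
    attn_input attn;
    recr_input rs;
    attn_input * get_attn() { return &attn; }
    recr_input * get_recr() { return &rs; }
};

// Stand-in for the K-only MLA hybrid memory input (build_inp_mem_hybrid_k).
struct mem_hybrid_k {
    attn_input attn;
    recr_input rs;
    attn_input * get_attn() { return &attn; }
    recr_input * get_recr() { return &rs; }
};

struct hparams_t {
    bool mla;
    bool is_mla() const { return mla; }
};

int main() {
    hparams_t hparams{ /*mla =*/ true };

    static mem_hybrid_kv kv_store;  // built only when MLA is off
    static mem_hybrid_k  k_store;   // built only when MLA is on

    // Build exactly one hybrid-input variant, matching the KV-cache layout.
    mem_hybrid_kv * inp_kv = !hparams.is_mla() ? &kv_store : nullptr;
    mem_hybrid_k  * inp_k  =  hparams.is_mla() ? &k_store  : nullptr;

    // Pull the recurrent/attention sub-inputs from whichever variant exists.
    recr_input * inp_rs      =  hparams.is_mla() ? inp_k->get_recr()  : inp_kv->get_recr();
    attn_input * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
    attn_input * inp_attn_k  =  hparams.is_mla() ? inp_k->get_attn()  : nullptr;

    // Downstream, the MLA attention path uses inp_attn_k and the MHA
    // fallback uses inp_attn_kv (see the two build_attn calls in the diff).
    std::printf("using %s\n", inp_attn_k ? "inp_attn_k (MLA)" : "inp_attn_kv (MHA)");
    (void) inp_rs;
    (void) inp_attn_kv;
    return 0;
}

The upshot is that the type of hybrid memory input now matches the KV-cache layout selected by hparams.is_mla(), which is what the commit title refers to.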


@@ -72,9 +72,11 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
 // So we don't need inp_pos
-auto * inp = build_inp_mem_hybrid_k();
-auto * inp_rs = inp->get_recr();
-auto * inp_attn = inp->get_attn();
+auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
 // Output ids for selecting which tokens to output
 ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -272,7 +274,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 ggml_tensor * Vcur = kv_cmpr;
 cb(Vcur, "Vcur", il);
-cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
 cb(cur, "mla_out", il);
 } else { // MLA KV cache disabled. Fall back to MHA KV cache.
 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
@@ -302,7 +304,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 // Direct softmax attention (with MHA KV cache)
 // Use build_attn with inp_attn for proper mask handling
-cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
 cb(cur, "mla_out", il);
 }
 } else {