use is_mla to switch between different mem_hybrid types
parent 2c8cd844d0 · commit 11282a0f60
@@ -72,9 +72,11 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
     // So we don't need inp_pos
 
-    auto * inp = build_inp_mem_hybrid_k();
-    auto * inp_rs = inp->get_recr();
-    auto * inp_attn = inp->get_attn();
+    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
 
     // Output ids for selecting which tokens to output
     ggml_tensor * inp_out_ids = build_inp_out_ids();
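
The selection logic above can be restated as a minimal, self-contained C++ sketch. The struct definitions, main(), and printf output below are hypothetical stand-ins for illustration only; just the builder/getter names and the is_mla-driven ternaries mirror the patch, under the assumption that exactly one of the two hybrid inputs is constructed and the other stays nullptr.

#include <cstdio>
#include <memory>

// Hypothetical stand-ins: in the real code, build_inp_mem_hybrid() and
// build_inp_mem_hybrid_k() return graph-input objects exposing get_recr()
// (recurrent state) and get_attn() (attention cache view).
struct inp_recr {};
struct inp_attn {};

struct inp_mem_hybrid {      // standard (MHA) hybrid cache flavour
    inp_recr rs; inp_attn attn;
    inp_recr * get_recr() { return &rs; }
    inp_attn * get_attn() { return &attn; }
};

struct inp_mem_hybrid_k {    // MLA (K-only) hybrid cache flavour
    inp_recr rs; inp_attn attn;
    inp_recr * get_recr() { return &rs; }
    inp_attn * get_attn() { return &attn; }
};

int main() {
    const bool is_mla = true;  // stands in for hparams.is_mla()

    // Exactly one of the two inputs is built; the other stays nullptr.
    auto inp_kv = !is_mla ? std::make_unique<inp_mem_hybrid>()   : nullptr;
    auto inp_k  =  is_mla ? std::make_unique<inp_mem_hybrid_k>() : nullptr;

    // Downstream views come from whichever input exists, as in the patch.
    inp_recr * inp_rs      =  is_mla ? inp_k->get_recr()  : inp_kv->get_recr();
    inp_attn * inp_attn_kv = !is_mla ? inp_kv->get_attn() : nullptr;
    inp_attn * inp_attn_k  =  is_mla ? inp_k->get_attn()  : nullptr;

    std::printf("is_mla=%d rs=%p attn_kv=%p attn_k=%p\n", is_mla,
                (void *) inp_rs, (void *) inp_attn_kv, (void *) inp_attn_k);
    return 0;
}

Only the branch selected by is_mla dereferences its pointer, so the nullptr on the unused side is never touched; that same invariant is what makes the ternaries in the patch safe.
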
@@ -272,7 +274,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
                 ggml_tensor * Vcur = kv_cmpr;
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
                 cb(cur, "mla_out", il);
             } else { // MLA KV cache disabled. Fall back to MHA KV cache.
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
@@ -302,7 +304,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
                 // Direct softmax attention (with MHA KV cache)
                 // Use build_attn with inp_attn for proper mask handling
-                cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
                 cb(cur, "mla_out", il);
             }
         } else {
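
The two build_attn call sites in the hunks above are the consumers of the is_mla selection: the MLA branch passes inp_attn_k together with layer.wv_b (which appears to be the MLA V decompression tensor), while the MHA fallback passes inp_attn_kv and no such tensor. A hedged, runnable sketch of that dispatch, with attn_input and build_attn_stub as hypothetical stand-ins for the real graph-builder API:

#include <cstdio>
#include <initializer_list>

// Hypothetical stand-in for build_attn(); here we model only which attention
// input and which V-absorption tensor each branch hands over.
struct attn_input { const char * name; };

static void build_attn_stub(const attn_input * inp, const char * wv_b) {
    std::printf("build_attn(%s, ..., wv_b=%s)\n", inp->name, wv_b);
}

int main() {
    attn_input attn_k  { "inp_attn_k"  };  // MLA (K-only) cache input
    attn_input attn_kv { "inp_attn_kv" };  // MHA fallback cache input

    for (bool is_mla : { true, false }) {
        if (is_mla) {
            // MLA KV cache enabled: K-only input plus layer.wv_b
            build_attn_stub(&attn_k, "layer.wv_b");
        } else {
            // MLA disabled, MHA fallback: standard input, no extra tensor
            build_attn_stub(&attn_kv, "nullptr");
        }
    }
    return 0;
}

Keeping the input object paired with the cache type it was built for is the point of the patch: each branch now receives the mem_hybrid view that matches its own KV layout.
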