changed hparams.kda_head_dim to hparams.n_embd_head_kda. added TODO comment for class llm_graph_input_mem_hybrid_k

Yee Man Chan 2026-01-29 08:35:35 +08:00
parent 0444a4faa0
commit a6b2c450c8
5 changed files with 12 additions and 9 deletions

@@ -533,6 +533,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
 }
+// TODO: Hybrid input classes are a bit redundant.
+// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
+// Refactoring is required in the future.
 void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
     mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);

@@ -139,10 +139,10 @@ uint32_t llama_hparams::n_embd_r() const {
         return n_embd * (n_shortconv_l_cache - 1);
     }
-    if (kda_head_dim != 0) {
+    if (n_embd_head_kda != 0) {
         // for Kimi KDA layers
         // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
-        const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096
+        const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096
         return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner;
     }
@@ -158,11 +158,11 @@ uint32_t llama_hparams::n_embd_s() const {
         return n_embd * wkv_head_size;
     }
-    if (kda_head_dim != 0) {
+    if (n_embd_head_kda != 0) {
         // for Kimi KDA layers
         // Full recurrent state: head_dim * head_dim * n_head
         // h tensor shape for delta attention: [head_dim, head_dim, n_head]
-        return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288
+        return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288
     }
     // corresponds to Mamba's ssm_states size
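
For reference, the two accessors above can be sanity-checked with a small standalone sketch (not part of this patch). The n_head = 32 and head_dim = 128 values come from the comments in the diff; ssm_d_conv = 4 is an assumption consistent with the (ssm_d_conv - 1) = 3 fallback used in n_embd_r().

// Standalone sketch, not part of the patch: per-layer KDA cache sizes.
// Hyperparameter values are assumptions taken from the comments in this diff.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_head          = 32;  // assumed (comment: "32 * 128 = 4096")
    const uint32_t n_embd_head_kda = 128; // assumed KDA head dim
    const uint32_t ssm_d_conv      = 4;   // assumed; matches the (ssm_d_conv - 1) = 3 fallback

    const uint32_t d_inner  = n_head * n_embd_head_kda;                   // 4096
    const uint32_t n_embd_r = 3 * (ssm_d_conv - 1) * d_inner;             // conv state for Q/K/V: 36864
    const uint32_t n_embd_s = n_embd_head_kda * n_embd_head_kda * n_head; // recurrent state: 524288

    std::printf("d_inner = %u, n_embd_r = %u, n_embd_s = %u\n", d_inner, n_embd_r, n_embd_s);
    return 0;
}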

@@ -138,7 +138,7 @@ struct llama_hparams {
     uint32_t ssm_n_group = 0;
     // for Kimi Linear KDA
-    uint32_t kda_head_dim = 0;
+    uint32_t n_embd_head_kda = 0;
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

@@ -2459,7 +2459,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
                 ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
                 ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim);
+                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
                 // MLA qk_rope_head_dim (for reference)
                 // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192
@@ -6801,8 +6801,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // Assuming KDA layer if KDA tensors are present
                     // KDA uses head_dim = 128 (from linear_attn_config.head_dim)
-                    const int64_t n_embd_head_k_kda = hparams.kda_head_dim;
-                    const int64_t n_embd_head_v_kda = hparams.kda_head_dim;
+                    const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda;
+                    const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda;
                     const int64_t ssm_d_conv = hparams.ssm_d_conv;
                     // Try loading KDA specific tensors (using SSM_ prefix)

@@ -92,7 +92,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Kimi dimension constants
     const int64_t n_head = hparams.n_head();
-    const int64_t head_dim = hparams.kda_head_dim;
+    const int64_t head_dim = hparams.n_embd_head_kda;
     const int64_t d_conv = hparams.ssm_d_conv;
     const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096
     const int64_t n_seqs = ubatch.n_seqs;
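
As a closing consistency note, here is a minimal sketch (again an assumption-based illustration, not code from the patch): the builder's d_inner and the n_embd_s() formula describe the same per-layer recurrent state size, since the [head_dim, head_dim, n_head] delta-attention state flattens to head_dim * d_inner elements.

// Consistency sketch, not part of the patch; values assumed from the comments above.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_head   = 32;                            // assumed
    const int64_t head_dim = 128;                           // hparams.n_embd_head_kda (assumed value)
    const int64_t d_inner  = n_head * head_dim;             // 4096, as in the builder
    const int64_t n_embd_s = head_dim * head_dim * n_head;  // 524288, as in llama_hparams::n_embd_s()

    // [head_dim, head_dim, n_head] flattens to head_dim * d_inner elements.
    assert(n_embd_s == head_dim * d_inner);
    return 0;
}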