diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1aebc012a1..ac143bf031 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -533,6 +533,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+// TODO: Hybrid input classes are a bit redundant.
+// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
+// Refactoring is required in the future.
 void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
     mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 873c65cea8..756dda1a7a 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -139,10 +139,10 @@ uint32_t llama_hparams::n_embd_r() const {
         return n_embd * (n_shortconv_l_cache - 1);
     }
 
-    if (kda_head_dim != 0) {
+    if (n_embd_head_kda != 0) {
         // for Kimi KDA layers
         // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
-        const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096
+        const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096
         return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner;
     }
 
@@ -158,11 +158,11 @@ uint32_t llama_hparams::n_embd_s() const {
         return n_embd * wkv_head_size;
     }
 
-    if (kda_head_dim != 0) {
+    if (n_embd_head_kda != 0) {
         // for Kimi KDA layers
         // Full recurrent state: head_dim * head_dim * n_head
         // h tensor shape for delta attention: [head_dim, head_dim, n_head]
-        return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288
+        return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288
     }
 
     // corresponds to Mamba's ssm_states size
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 943161747c..a736ccc3d0 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -138,7 +138,7 @@ struct llama_hparams {
     uint32_t ssm_n_group = 0;
 
     // for Kimi Linear KDA
-    uint32_t kda_head_dim = 0;
+    uint32_t n_embd_head_kda = 0;
 
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 50900feb2c..40f3ff6e49 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2459,7 +2459,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
                 ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
                 ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim);
+                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
 
                 // MLA qk_rope_head_dim (for reference)
                 // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192
@@ -6801,8 +6801,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // Assuming KDA layer if KDA tensors are present
                     // KDA uses head_dim = 128 (from linear_attn_config.head_dim)
-                    const int64_t n_embd_head_k_kda = hparams.kda_head_dim;
-                    const int64_t n_embd_head_v_kda = hparams.kda_head_dim;
+                    const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda;
+                    const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda;
                     const int64_t ssm_d_conv = hparams.ssm_d_conv;
 
                     // Try loading KDA specific tensors (using SSM_ prefix)
diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index 40007a6fa3..5f497722d0 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -92,7 +92,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
     // Kimi dimension constants
     const int64_t n_head = hparams.n_head();
-    const int64_t head_dim = hparams.kda_head_dim;
+    const int64_t head_dim = hparams.n_embd_head_kda;
    const int64_t d_conv = hparams.ssm_d_conv;
     const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096
     const int64_t n_seqs = ubatch.n_seqs;
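
For reference, below is a small standalone sketch (not part of the patch) of the per-layer KDA cache sizes that the n_embd_r() / n_embd_s() changes above compute. The dimensions n_head = 32 and n_embd_head_kda = 128 come from the comments in the diff; the ssm_d_conv = 4 value and the main() harness are illustrative assumptions, not values taken from llama.cpp.

// Standalone sketch: reproduces the KDA state-size arithmetic from
// llama_hparams::n_embd_r() / n_embd_s() with example dimensions.
#include <cstdint>
#include <cstdio>

int main() {
    // Example Kimi Linear KDA dimensions (from the comments in the diff):
    const uint32_t n_head          = 32;   // KDA attention heads
    const uint32_t n_embd_head_kda = 128;  // KDA head dimension
    const uint32_t ssm_d_conv      = 4;    // assumed conv kernel size; the code falls back to d_conv - 1 == 3 when unset

    // n_embd_r(): conv state for Q, K, V -> 3 * (d_conv - 1) * n_head * head_dim
    const uint32_t d_inner  = n_head * n_embd_head_kda;                            // 32 * 128 = 4096
    const uint32_t n_embd_r = 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; // 3 * 3 * 4096 = 36864

    // n_embd_s(): full recurrent state -> head_dim * head_dim * n_head
    const uint32_t n_embd_s = n_embd_head_kda * n_embd_head_kda * n_head;          // 128 * 128 * 32 = 524288

    printf("n_embd_r = %u, n_embd_s = %u\n", n_embd_r, n_embd_s);
    return 0;
}

The rename itself also brings the field in line with the existing n_embd_head_k / n_embd_head_v naming convention in llama_hparams.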