Changed hparams.kda_head_dim to hparams.n_embd_head_kda; added a TODO comment above llm_graph_input_mem_hybrid_k::set_input.
parent 0444a4faa0
commit a6b2c450c8
@@ -533,6 +533,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+// TODO: Hybrid input classes are a bit redundant.
+// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
+// Refactoring is required in the future.
 void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
     mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
@@ -139,10 +139,10 @@ uint32_t llama_hparams::n_embd_r() const {
         return n_embd * (n_shortconv_l_cache - 1);
     }
 
-    if (kda_head_dim != 0) {
+    if (n_embd_head_kda != 0) {
         // for Kimi KDA layers
         // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
-        const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096
+        const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096
         return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner;
     }
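As a sanity check on the sizes in the hunk above, here is a minimal standalone sketch of the same conv-state arithmetic, assuming the dimensions named in the diff comments (n_head = 32, head_dim = 128) and a hypothetical conv kernel of ssm_d_conv = 4 (so d_conv - 1 = 3, matching the fallback used when ssm_d_conv is 0); it is not the real llama_hparams code.

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed KDA dimensions (from the diff comments; ssm_d_conv is a guess)
    const uint32_t n_head     = 32;
    const uint32_t head_dim   = 128; // n_embd_head_kda
    const uint32_t ssm_d_conv = 4;

    // Same formula as llama_hparams::n_embd_r() for a KDA layer:
    // conv state for Q, K, V -> 3 * (d_conv - 1) * n_head * head_dim
    const uint32_t d_inner  = n_head * head_dim;                                   // 32 * 128 = 4096
    const uint32_t n_embd_r = 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; // 3 * 3 * 4096 = 36864

    std::printf("d_inner = %u, n_embd_r = %u\n", d_inner, n_embd_r);
    return 0;
}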
@@ -158,11 +158,11 @@ uint32_t llama_hparams::n_embd_s() const {
         return n_embd * wkv_head_size;
     }
 
-    if (kda_head_dim != 0) {
+    if (n_embd_head_kda != 0) {
         // for Kimi KDA layers
         // Full recurrent state: head_dim * head_dim * n_head
         // h tensor shape for delta attention: [head_dim, head_dim, n_head]
-        return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288
+        return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288
     }
 
     // corresponds to Mamba's ssm_states size
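The per-sequence recurrent state in the hunk above is just head_dim * head_dim * n_head; here is a tiny compile-time check of that product, using the values from the comments (128 and 32) as assumptions rather than reading them from a model.

#include <cstdint>

// Hypothetical constants mirroring the diff comments, not the real hparams.
constexpr uint32_t head_dim = 128; // n_embd_head_kda
constexpr uint32_t n_head   = 32;

// h tensor shape for delta attention: [head_dim, head_dim, n_head]
static_assert(head_dim * head_dim * n_head == 524288, "128 * 128 * 32 = 524288");

int main() { return 0; }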
@@ -138,7 +138,7 @@ struct llama_hparams {
     uint32_t ssm_n_group = 0;
 
     // for Kimi Linear KDA
-    uint32_t kda_head_dim = 0;
+    uint32_t n_embd_head_kda = 0;
 
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
@@ -2459,7 +2459,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
                 ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
                 ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim);
+                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
 
                 // MLA qk_rope_head_dim (for reference)
                 // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192
@@ -6801,8 +6801,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // Assuming KDA layer if KDA tensors are present
 
                     // KDA uses head_dim = 128 (from linear_attn_config.head_dim)
-                    const int64_t n_embd_head_k_kda = hparams.kda_head_dim;
-                    const int64_t n_embd_head_v_kda = hparams.kda_head_dim;
+                    const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda;
+                    const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda;
                     const int64_t ssm_d_conv = hparams.ssm_d_conv;
 
                     // Try loading KDA specific tensors (using SSM_ prefix)
@@ -92,7 +92,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
     // Kimi dimension constants
     const int64_t n_head = hparams.n_head();
-    const int64_t head_dim = hparams.kda_head_dim;
+    const int64_t head_dim = hparams.n_embd_head_kda;
     const int64_t d_conv = hparams.ssm_d_conv;
     const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096
     const int64_t n_seqs = ubatch.n_seqs;
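To illustrate what the rename means at a call site like the one above, here is a hypothetical, trimmed-down stand-in for the hyperparameter struct and the derived dimension constants; the struct and its default values are assumptions for illustration, not the real llama_hparams or llm_build_kimi_linear.

#include <cstdint>

// Stand-in for the relevant llama_hparams fields after the rename.
struct kda_hparams_sketch {
    uint32_t n_embd_head_kda = 128; // was kda_head_dim before this commit
    uint32_t ssm_d_conv      = 4;   // assumed conv kernel size
    uint32_t n_head          = 32;
};

int main() {
    const kda_hparams_sketch hp;

    // Same derived constants as in the Kimi Linear graph builder above.
    const int64_t head_dim = hp.n_embd_head_kda;
    const int64_t d_conv   = hp.ssm_d_conv;
    const int64_t d_inner  = hp.n_head * head_dim; // 32 * 128 = 4096

    return (d_inner == 4096 && d_conv == 4) ? 0 : 1;
}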