working unified delta net

Yee Man Chan 2026-01-27 14:11:27 +08:00
parent 62dbea1b19
commit 240bd4b29e
2 changed files with 17 additions and 0 deletions

src/llama-hparams.cpp

@@ -139,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const {
        return n_embd * (n_shortconv_l_cache - 1);
    }
    if (kda_head_dim != 0) {
        // for Kimi KDA layers
        // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
        const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096
        return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; // fallback 3 = d_conv - 1 for the default d_conv of 4
    }
    // TODO: maybe support other convolution strides than 1
    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
    // Corresponds to Mamba's conv_states size
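For reference, with the configuration quoted in the comments above (n_head = 32, kda_head_dim = 128) and the fallback d_conv of 4, the new branch yields 3 * 3 * 4096 = 36864 elements per sequence. Below is a minimal standalone sketch of the same arithmetic; the helper name and the hard-coded defaults are illustrative only and not part of the commit:

    #include <cstdint>
    #include <cstdio>

    // per-sequence conv state size for a KDA layer, mirroring the
    // branch added to llama_hparams::n_embd_r() above
    static uint32_t kda_conv_state_size(uint32_t n_head, uint32_t head_dim, uint32_t d_conv) {
        const uint32_t d_inner = n_head * head_dim; // 32 * 128 = 4096
        return 3 * (d_conv - 1) * d_inner;          // Q, K and V each keep d_conv - 1 shifted columns
    }

    int main() {
        std::printf("%u\n", kda_conv_state_size(32, 128, 4)); // prints 36864
    }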
@@ -151,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const {
        return n_embd * wkv_head_size;
    }
    if (kda_head_dim != 0) {
        // for Kimi KDA layers
        // Full recurrent state: head_dim * head_dim * n_head
        // h tensor shape for delta attention: [head_dim, head_dim, n_head]
        return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288
    }
    // corresponds to Mamba's ssm_states size
    return ssm_d_state * ssm_d_inner;
}
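The square [head_dim, head_dim] block per head sized here is the running state matrix S of a delta-rule recurrence. As a rough sketch of why the state is square: the standard DeltaNet update is S <- S + beta * (v - S k) k^T per token; KDA additionally applies a learned per-channel decay to S, which the illustration below omits. The commit itself only sizes this buffer; the actual update is built elsewhere in the compute graph. Illustrative single-head step, not code from the commit:

    #include <vector>

    // one delta-rule step for a single head (illustration only):
    // S <- S + beta * (v - S k) k^T, with S stored row-major as head_dim x head_dim
    static void delta_rule_step(std::vector<float> & S,
                                const std::vector<float> & k,
                                const std::vector<float> & v,
                                float beta, int head_dim) {
        // Sk = S * k  (what the memory currently predicts for this key)
        std::vector<float> Sk(head_dim, 0.0f);
        for (int i = 0; i < head_dim; i++) {
            for (int j = 0; j < head_dim; j++) {
                Sk[i] += S[i*head_dim + j] * k[j];
            }
        }
        // rank-1 correction of the memory towards the new value
        for (int i = 0; i < head_dim; i++) {
            const float err = beta * (v[i] - Sk[i]);
            for (int j = 0; j < head_dim; j++) {
                S[i*head_dim + j] += err * k[j];
            }
        }
    }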

src/llama-hparams.h

@@ -137,6 +137,9 @@ struct llama_hparams {
    uint32_t ssm_dt_rank = 0;
    uint32_t ssm_n_group = 0;

    // for Kimi Linear KDA
    uint32_t kda_head_dim = 0;

    // for hybrid state space models
    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
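With kda_head_dim set to a nonzero value, both new branches above are taken, so a KDA layer's per-sequence recurrent cache is n_embd_r() + n_embd_s() elements. Under the same assumed 32-head, 128-dim configuration that works out to 36864 + 524288 = 561152 elements per layer, dominated by the square delta-rule state rather than the conv state.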