working unified delta net
This commit is contained in:
parent
62dbea1b19
commit
240bd4b29e
|
|
@ -139,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const {
|
|||
return n_embd * (n_shortconv_l_cache - 1);
|
||||
}
|
||||
|
||||
if (kda_head_dim != 0) {
|
||||
// for Kimi KDA layers
|
||||
// Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
|
||||
const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096
|
||||
return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner;
|
||||
}
|
||||
|
||||
// TODO: maybe support other convolution strides than 1
|
||||
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
||||
// Corresponds to Mamba's conv_states size
|
||||
|
|
@ -151,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const {
|
|||
return n_embd * wkv_head_size;
|
||||
}
|
||||
|
||||
if (kda_head_dim != 0) {
|
||||
// for Kimi KDA layers
|
||||
// Full recurrent state: head_dim * head_dim * n_head
|
||||
// h tensor shape for delta attention: [head_dim, head_dim, n_head]
|
||||
return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288
|
||||
}
|
||||
|
||||
// corresponds to Mamba's ssm_states size
|
||||
return ssm_d_state * ssm_d_inner;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -137,6 +137,9 @@ struct llama_hparams {
|
|||
uint32_t ssm_dt_rank = 0;
|
||||
uint32_t ssm_n_group = 0;
|
||||
|
||||
// for Kimi Linear KDA
|
||||
uint32_t kda_head_dim = 0;
|
||||
|
||||
// for hybrid state space models
|
||||
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue