From 240bd4b29e309dfe564ceaa6e763d585a53cd45d Mon Sep 17 00:00:00 2001
From: Yee Man Chan
Date: Tue, 27 Jan 2026 14:11:27 +0800
Subject: [PATCH] working unified delta net

---
 src/llama-hparams.cpp | 14 ++++++++++++++
 src/llama-hparams.h   |  3 +++
 2 files changed, 17 insertions(+)

diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 5f1df995f3..d538c56216 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -139,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const {
         return n_embd * (n_shortconv_l_cache - 1);
     }
 
+    if (kda_head_dim != 0) {
+        // for Kimi KDA layers
+        // conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
+        const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096
+        return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; // fallback 3 assumes d_conv == 4
+    }
+
     // TODO: maybe support other convolution strides than 1
     // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
     // Corresponds to Mamba's conv_states size
@@ -151,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const {
         return n_embd * wkv_head_size;
     }
 
+    if (kda_head_dim != 0) {
+        // for Kimi KDA layers
+        // full recurrent state: head_dim * head_dim * n_head
+        // h tensor shape for delta attention: [head_dim, head_dim, n_head]
+        return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288
+    }
+
     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
 }
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 2bf8665520..1876e294f6 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -137,6 +137,9 @@ struct llama_hparams {
     uint32_t ssm_dt_rank = 0;
     uint32_t ssm_n_group = 0;
 
+    // for Kimi Linear KDA
+    uint32_t kda_head_dim = 0;
+
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
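
Note (not part of the patch): below is a minimal standalone sanity check of the two per-layer cache-size formulas added above, plus a toy delta-rule step showing why the recurrent state in n_embd_s() is a head_dim x head_dim matrix per head. The shape constants (32 heads, head_dim 128, d_conv 4) are assumptions taken from the inline comments in the diff, not read from a model file, and the update rule shown is the generic DeltaNet form, not necessarily Kimi's exact gated variant.

// sanity-check.cpp -- compile with: g++ -std=c++11 sanity-check.cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_head       = 32;   // assumed, from the diff comments
    const uint32_t kda_head_dim = 128;  // assumed, from the diff comments
    const uint32_t ssm_d_conv   = 4;    // matches the fallback (d_conv - 1 == 3)

    // n_embd_r(): shift-conv state for Q, K, V -> 3 * (d_conv - 1) * d_inner
    const uint32_t d_inner  = n_head * kda_head_dim;
    const uint32_t n_embd_r = 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner;

    // n_embd_s(): recurrent state h of shape [head_dim, head_dim, n_head]
    const uint32_t n_embd_s = kda_head_dim * kda_head_dim * n_head;

    printf("n_embd_r = %u (expect 36864)\n",  n_embd_r);  // 3 * 3 * 4096
    printf("n_embd_s = %u (expect 524288)\n", n_embd_s);  // 128 * 128 * 32

    // Why the state is head_dim x head_dim per head: the delta rule keeps a
    // d x d matrix S per head. One generic (not Kimi-specific) update step:
    //   S <- S + beta * (v - S k) * k^T,   o = S q
    const uint32_t d = 4; // tiny toy dimension
    std::vector<float> S(d * d, 0.0f), k(d, 0.5f), v(d, 1.0f);
    const float beta = 0.9f;
    for (uint32_t r = 0; r < d; ++r) {
        float Sk = 0.0f;                       // (S k)[r], from the old row
        for (uint32_t c = 0; c < d; ++c) {
            Sk += S[r * d + c] * k[c];
        }
        for (uint32_t c = 0; c < d; ++c) {
            S[r * d + c] += beta * (v[r] - Sk) * k[c]; // rank-1 delta update
        }
    }
    printf("S[0][0] after one delta step = %f\n", S[0]); // 0.9 * 1.0 * 0.5 = 0.45
    return 0;
}

The expected totals (36864 and 524288 floats per layer) are what the two hunks hand to the recurrent-cache allocator; if a converted GGUF reports different n_head or kda_head_dim, the printed values will differ accordingly.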