kimi linear ggml-cpu

This commit is contained in:
Yee Man Chan 2025-12-02 11:20:46 +08:00
parent 6167f39e08
commit 26a6553155
3 changed files with 202 additions and 0 deletions

View File

@ -1962,6 +1962,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_ssm_scan(params, tensor);
} break;
case GGML_OP_KDA_SCAN:
{
ggml_compute_forward_kda_scan(params, tensor);
} break;
case GGML_OP_WIN_PART:
{
ggml_compute_forward_win_part(params, tensor);
@ -2320,6 +2324,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_KDA_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:

View File

@ -8686,6 +8686,9 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
static int conv_debug_count = 0;
bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3);
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
// {d_conv - 1 + n_t, d_inner, n_seqs}
@ -8706,6 +8709,13 @@ static void ggml_compute_forward_ssm_conv_f32(
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
}
x[i1] = sumf;
// Debug output
if (do_conv_debug && i1 == 0 && i2 == 0 && i3 == 0) {
fprintf(stderr, "DEBUG SSM_CONV: nc=%d, nr=%d, n_t=%d, n_s=%d\n", nc, nr, n_t, n_s);
fprintf(stderr, "DEBUG SSM_CONV: s[0..3]=%f,%f,%f,%f, c[0..3]=%f,%f,%f,%f, x[0]=%f\n",
s[0], s[1], s[2], s[3], c[0], c[1], c[2], c[3], x[0]);
}
}
}
}
@ -8956,6 +8966,192 @@ void ggml_compute_forward_ssm_scan(
}
}
// ggml_compute_forward_kda_scan
// KDA (Kimi Delta Attention) recurrence:
// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// o[t] = q[t]^T @ h[t]
static void ggml_compute_forward_kda_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // h {head_dim, head_dim, n_head, n_seqs+}
const ggml_tensor * src1 = dst->src[1]; // q {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src2 = dst->src[2]; // k {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src3 = dst->src[3]; // v {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src4 = dst->src[4]; // g {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src5 = dst->src[5]; // beta {n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
const int64_t head_dim = src0->ne[0];
const int64_t n_head = src1->ne[1];
const int64_t n_seq_tokens = src1->ne[2];
const int64_t n_seqs = src1->ne[3];
// Output offset for hidden state
const int64_t y_off = ggml_nelements(src1) * sizeof(float);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// Parallelize over heads
const int dh = (n_head + nth - 1) / nth;
const int ih0 = dh * ith;
const int ih1 = MIN(ih0 + dh, (int)n_head);
const int32_t * ids = (const int32_t *) src6->data;
// Temporary buffer for h @ k computation
float * hk_buf = (float *) malloc(head_dim * sizeof(float));
static int debug_count = 0;
bool do_debug = false; // (ith == 0 && debug_count++ < 20);
for (int i3 = 0; i3 < n_seqs; ++i3) {
// Get initial hidden state for this sequence
const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]);
// Output hidden state location
float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off);
for (int ih = ih0; ih < ih1; ++ih) {
// Per-head hidden state: [head_dim, head_dim]
// Copy initial state to output (will be updated in place)
const float * h_in = h0 + ih * head_dim * head_dim;
float * h = h_out + ih * head_dim * head_dim;
// Copy initial state, but check for invalid values and clear if needed
bool need_clear = false;
for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) {
if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) {
need_clear = true;
}
}
for (int i = 0; i < head_dim * head_dim; ++i) {
h[i] = need_clear ? 0.0f : h_in[i];
}
for (int it = 0; it < n_seq_tokens; ++it) {
const float * q_raw = (const float *) ((const char *) src1->data +
it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim;
const float * k_raw = (const float *) ((const char *) src2->data +
it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim;
const float * v = (const float *) ((const char *) src3->data +
it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim;
const float * g = (const float *) ((const char *) src4->data +
it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim;
const float beta = ((const float *) ((const char *) src5->data +
it * src5->nb[1] + i3 * src5->nb[2]))[ih];
float * y = (float *) dst->data +
it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim;
// L2 normalize q and k (critical for KDA stability)
float q_norm = 0.0f, k_norm = 0.0f;
for (int i = 0; i < head_dim; ++i) {
q_norm += q_raw[i] * q_raw[i];
k_norm += k_raw[i] * k_raw[i];
}
q_norm = sqrtf(q_norm + 1e-6f);
k_norm = sqrtf(k_norm + 1e-6f);
// Debug output
if (do_debug && ih == 0 && it == 0 && i3 == 0) {
fprintf(stderr, "DEBUG KDA: q_raw[0]=%f, k_raw[0]=%f, v[0]=%f, g[0]=%f, beta=%f\n",
q_raw[0], k_raw[0], v[0], g[0], beta);
fprintf(stderr, "DEBUG KDA: q_norm=%f, k_norm=%f, exp(g[0])=%f, scale=%f\n",
q_norm, k_norm, expf(g[0]), 1.0f / sqrtf((float)head_dim));
}
// Normalized q and k with scale = 1/sqrt(head_dim)
// Note: scale is applied only to q after L2 normalization
const float scale = 1.0f / sqrtf((float)head_dim);
float q[128], k[128]; // assume head_dim <= 128
for (int i = 0; i < head_dim; ++i) {
// L2 normalize then scale q
q[i] = (q_raw[i] / q_norm) * scale;
// L2 normalize k (no scale)
k[i] = k_raw[i] / k_norm;
}
// KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// Note: Apply decay first, then compute retrieval and update
// Step 1: Apply decay to h first: h = h * exp(g)
for (int i = 0; i < head_dim; ++i) {
const float exp_gi = expf(g[i]);
for (int j = 0; j < head_dim; ++j) {
h[i * head_dim + j] *= exp_gi;
}
}
// Step 2: Compute h^T @ k -> hk_buf [head_dim]
// hk_buf[j] = sum_i (h[i,j] * k[i]) which is column j of h dotted with k
for (int j = 0; j < head_dim; ++j) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h[i * head_dim + j] * k[i];
}
hk_buf[j] = sum;
}
// Step 3: Compute delta = beta * (v - hk) and update h
// h = h + outer(k, delta) where outer(k,delta)[i,j] = k[i] * delta[j]
for (int i = 0; i < head_dim; ++i) {
for (int j = 0; j < head_dim; ++j) {
const float delta_j = beta * (v[j] - hk_buf[j]);
h[i * head_dim + j] += k[i] * delta_j;
}
}
// Step 4: Compute output y = h^T @ q -> [head_dim]
// vLLM: b_o = tl.sum(b_h * b_q[:, None], 0) means o[j] = sum_i(h[i,j] * q[i])
for (int j = 0; j < head_dim; ++j) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h[i * head_dim + j] * q[i];
}
y[j] = sum;
}
// Debug output
if (do_debug && ih == 0 && it == 0 && i3 == 0) {
// Find max abs value in h for stability check
float h_max = 0.0f;
for (int i = 0; i < head_dim * head_dim; i++) {
if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]);
}
fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n",
y[0], h_max, expf(g[0]));
}
}
}
}
free(hk_buf);
}
void ggml_compute_forward_kda_scan(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_kda_scan_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_win_part
static void ggml_compute_forward_win_part_f32(

View File

@ -92,6 +92,7 @@ void ggml_compute_forward_flash_attn_back(
struct ggml_tensor * dst);
void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);