From 26a6553155cb735c67b1db01f3901404ee0b8c9e Mon Sep 17 00:00:00 2001
From: Yee Man Chan
Date: Tue, 2 Dec 2025 11:20:46 +0800
Subject: [PATCH] kimi linear ggml-cpu

---
 ggml/src/ggml-cpu/ggml-cpu.c |   5 +
 ggml/src/ggml-cpu/ops.cpp    | 196 +++++++++++++++++++++++++++++++++++
 ggml/src/ggml-cpu/ops.h      |   1 +
 3 files changed, 202 insertions(+)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 3247af8bb0..7b40f1e8c2 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1962,6 +1962,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_ssm_scan(params, tensor);
             } break;
+        case GGML_OP_KDA_SCAN:
+            {
+                ggml_compute_forward_kda_scan(params, tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor);
@@ -2320,6 +2324,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_KDA_SCAN:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 608e82af69..9c93e0c101 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8686,6 +8686,9 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
+    static int conv_debug_count = 0;
+    bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3);
+
     for (int i3 = 0; i3 < n_s; ++i3) {
         for (int i2 = 0; i2 < n_t; ++i2) {
             // {d_conv - 1 + n_t, d_inner, n_seqs}
@@ -8706,6 +8709,13 @@
                     sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
                 }
                 x[i1] = sumf;
+
+                // Debug output
+                if (do_conv_debug && i1 == 0 && i2 == 0 && i3 == 0) {
+                    fprintf(stderr, "DEBUG SSM_CONV: nc=%d, nr=%d, n_t=%d, n_s=%d\n", nc, nr, n_t, n_s);
+                    fprintf(stderr, "DEBUG SSM_CONV: s[0..3]=%f,%f,%f,%f, c[0..3]=%f,%f,%f,%f, x[0]=%f\n",
+                            s[0], s[1], s[2], s[3], c[0], c[1], c[2], c[3], x[0]);
+                }
             }
         }
     }
@@ -8956,6 +8966,192 @@ void ggml_compute_forward_ssm_scan(
     }
 }
 
+// ggml_compute_forward_kda_scan
+// KDA (Kimi Delta Attention) recurrence:
+//   h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
+//   o[t] = q[t]^T @ h[t]
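+//
+// Notation, matching the loops below: per head, h is a row-major
+// [head_dim x head_dim] state matrix indexed as [key channel, value channel],
+// g is a per-channel log-decay applied along the key dimension, beta is a
+// scalar per-head update gate, and q/k are L2-normalized before use.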
+
+static void ggml_compute_forward_kda_scan_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // h    {head_dim, head_dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // q    {head_dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // k    {head_dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // v    {head_dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src4 = dst->src[4]; // g    {head_dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // beta {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids  {n_seqs}
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t head_dim     = src0->ne[0];
+    const int64_t n_head       = src1->ne[1];
+    const int64_t n_seq_tokens = src1->ne[2];
+    const int64_t n_seqs       = src1->ne[3];
+
+    // the fixed-size q/k scratch buffers below require this bound
+    GGML_ASSERT(head_dim <= 128);
+
+    // byte offset of the output hidden state within dst (the per-token outputs y come first)
+    const int64_t y_off = ggml_nelements(src1) * sizeof(float);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+
+    // Parallelize over heads
+    const int dh  = (n_head + nth - 1) / nth;
+    const int ih0 = dh * ith;
+    const int ih1 = MIN(ih0 + dh, (int) n_head);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    // Temporary buffer for the h^T @ k computation
+    float * hk_buf = (float *) malloc(head_dim * sizeof(float));
+    GGML_ASSERT(hk_buf != NULL);
+
+    static int debug_count = 0;
+    bool do_debug = false; // (ith == 0 && debug_count++ < 20);
+
+    for (int i3 = 0; i3 < n_seqs; ++i3) {
+        // Get initial hidden state for this sequence
+        const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]);
+        // Output hidden state location
+        float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off);
+
+        for (int ih = ih0; ih < ih1; ++ih) {
+            // Per-head hidden state: [head_dim, head_dim]
+            // Copy initial state to output (will be updated in place)
+            const float * h_in = h0 + ih * head_dim * head_dim;
+            float * h = h_out + ih * head_dim * head_dim;
+
+            // Copy initial state, but check for invalid values and clear if needed
+            bool need_clear = false;
+            for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) {
+                if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) {
+                    need_clear = true;
+                }
+            }
+            for (int i = 0; i < head_dim * head_dim; ++i) {
+                h[i] = need_clear ? 0.0f : h_in[i];
+            }
+
+            for (int it = 0; it < n_seq_tokens; ++it) {
+                const float * q_raw = (const float *) ((const char *) src1->data +
+                        it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim;
+                const float * k_raw = (const float *) ((const char *) src2->data +
+                        it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim;
+                const float * v = (const float *) ((const char *) src3->data +
+                        it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim;
+                const float * g = (const float *) ((const char *) src4->data +
+                        it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim;
+                const float beta = ((const float *) ((const char *) src5->data +
+                        it * src5->nb[1] + i3 * src5->nb[2]))[ih];
+
+                float * y = (float *) dst->data +
+                        it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim;
+
+                // L2 normalize q and k (critical for KDA stability)
+                float q_norm = 0.0f, k_norm = 0.0f;
+                for (int i = 0; i < head_dim; ++i) {
+                    q_norm += q_raw[i] * q_raw[i];
+                    k_norm += k_raw[i] * k_raw[i];
+                }
+                q_norm = sqrtf(q_norm + 1e-6f);
+                k_norm = sqrtf(k_norm + 1e-6f);
+
+                // Debug output
+                if (do_debug && ih == 0 && it == 0 && i3 == 0) {
+                    fprintf(stderr, "DEBUG KDA: q_raw[0]=%f, k_raw[0]=%f, v[0]=%f, g[0]=%f, beta=%f\n",
+                            q_raw[0], k_raw[0], v[0], g[0], beta);
+                    fprintf(stderr, "DEBUG KDA: q_norm=%f, k_norm=%f, exp(g[0])=%f, scale=%f\n",
+                            q_norm, k_norm, expf(g[0]), 1.0f / sqrtf((float) head_dim));
+                }
+
+                // Normalized q and k with scale = 1/sqrt(head_dim)
+                // Note: the scale is applied only to q, after L2 normalization
+                const float scale = 1.0f / sqrtf((float) head_dim);
+                float q[128], k[128]; // head_dim <= 128, asserted above
+                for (int i = 0; i < head_dim; ++i) {
+                    // L2 normalize then scale q
+                    q[i] = (q_raw[i] / q_norm) * scale;
+                    // L2 normalize k (no scale)
+                    k[i] = k_raw[i] / k_norm;
+                }
+
+                // KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
+                // Note: apply the decay first, then compute the retrieval and the update
+
+                // Step 1: apply the decay to h first: h = h * exp(g)
+                for (int i = 0; i < head_dim; ++i) {
+                    const float exp_gi = expf(g[i]);
+                    for (int j = 0; j < head_dim; ++j) {
+                        h[i * head_dim + j] *= exp_gi;
+                    }
+                }
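+
+                // Equivalently, Step 1 is h <- diag(exp(g)) @ h: each key channel
+                // decays independently. For a stable recurrence the model is expected
+                // to produce g <= 0, so exp(g) is a per-channel decay in (0, 1].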
+
+                // Step 2: compute h^T @ k -> hk_buf [head_dim]
+                //         hk_buf[j] = sum_i(h[i,j] * k[i]), i.e. column j of h dotted with k
+                for (int j = 0; j < head_dim; ++j) {
+                    float sum = 0.0f;
+                    for (int i = 0; i < head_dim; ++i) {
+                        sum += h[i * head_dim + j] * k[i];
+                    }
+                    hk_buf[j] = sum;
+                }
+
+                // Step 3: compute delta = beta * (v - hk) once (reusing hk_buf), then
+                //         apply the rank-1 update h = h + outer(k, delta),
+                //         where outer(k, delta)[i,j] = k[i] * delta[j]
+                for (int j = 0; j < head_dim; ++j) {
+                    hk_buf[j] = beta * (v[j] - hk_buf[j]);
+                }
+                for (int i = 0; i < head_dim; ++i) {
+                    for (int j = 0; j < head_dim; ++j) {
+                        h[i * head_dim + j] += k[i] * hk_buf[j];
+                    }
+                }
+
+                // Step 4: compute the output y = h^T @ q -> [head_dim]
+                //         vLLM: b_o = tl.sum(b_h * b_q[:, None], 0), i.e. o[j] = sum_i(h[i,j] * q[i])
+                for (int j = 0; j < head_dim; ++j) {
+                    float sum = 0.0f;
+                    for (int i = 0; i < head_dim; ++i) {
+                        sum += h[i * head_dim + j] * q[i];
+                    }
+                    y[j] = sum;
+                }
+
+                // Debug output
+                if (do_debug && ih == 0 && it == 0 && i3 == 0) {
+                    // Find the max abs value in h for a stability check
+                    float h_max = 0.0f;
+                    for (int i = 0; i < head_dim * head_dim; i++) {
+                        if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]);
+                    }
+                    fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n",
+                            y[0], h_max, expf(g[0]));
+                }
+            }
+        }
+    }
+
+    free(hk_buf);
+}
+
+void ggml_compute_forward_kda_scan(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_kda_scan_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_win_part
 
 static void ggml_compute_forward_win_part_f32(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 0fdfee7976..080cf6e090 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -92,6 +92,7 @@ void ggml_compute_forward_flash_attn_back(
         struct ggml_tensor * dst);
 void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
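
Reviewer note, not part of the patch: the per-token update above is easiest to cross-check against a scalar reference. The sketch below mirrors Steps 1-4 of ggml_compute_forward_kda_scan_f32 for a single head and a single token; kda_step_ref is a hypothetical name, and it assumes q and k are already L2-normalized with the 1/sqrt(head_dim) scale folded into q, exactly as the kernel does.

// Hypothetical scalar reference for one KDA step (single head, single token).
// h is a row-major [D x D] state, indexed [key channel, value channel].
#include <cmath>
#include <vector>

static void kda_step_ref(std::vector<float> & h,           // [D*D] state, updated in place
                         const float * q, const float * k, // [D], pre-normalized
                         const float * v, const float * g, // [D] value and log-decay
                         float beta, float * y, int D) {   // per-head gate, output y[D]
    std::vector<float> delta(D);
    // Step 1: h <- diag(exp(g)) @ h
    for (int i = 0; i < D; ++i) {
        const float eg = std::exp(g[i]);
        for (int j = 0; j < D; ++j) {
            h[i*D + j] *= eg;
        }
    }
    // Steps 2+3: delta = beta * (v - h^T @ k), then h <- h + outer(k, delta)
    for (int j = 0; j < D; ++j) {
        float hk = 0.0f;
        for (int i = 0; i < D; ++i) {
            hk += h[i*D + j] * k[i];
        }
        delta[j] = beta * (v[j] - hk);
    }
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < D; ++j) {
            h[i*D + j] += k[i] * delta[j];
        }
    }
    // Step 4: y = h^T @ q
    for (int j = 0; j < D; ++j) {
        float s = 0.0f;
        for (int i = 0; i < D; ++i) {
            s += h[i*D + j] * q[i];
        }
        y[j] = s;
    }
}

Iterating this from h = 0 over a short random sequence and comparing y and the final state against the kernel's dst (the y values first, then the packed states at y_off) should catch indexing or decay-order regressions.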