kimi linear ggml-cpu
parent 6167f39e08
commit 26a6553155

@@ -1962,6 +1962,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
            {
                ggml_compute_forward_ssm_scan(params, tensor);
            } break;
        case GGML_OP_KDA_SCAN:
            {
                ggml_compute_forward_kda_scan(params, tensor);
            } break;
        case GGML_OP_WIN_PART:
            {
                ggml_compute_forward_win_part(params, tensor);

@@ -2320,6 +2324,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_FLASH_ATTN_BACK:
        case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
        case GGML_OP_KDA_SCAN:
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:

@@ -8686,6 +8686,9 @@ static void ggml_compute_forward_ssm_conv_f32(
    const int ir1 = MIN(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    static int conv_debug_count = 0;
    bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3);

    for (int i3 = 0; i3 < n_s; ++i3) {
        for (int i2 = 0; i2 < n_t; ++i2) {
            // {d_conv - 1 + n_t, d_inner, n_seqs}

@@ -8706,6 +8709,13 @@ static void ggml_compute_forward_ssm_conv_f32(
                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
                }
                x[i1] = sumf;

                // Debug output
                if (do_conv_debug && i1 == 0 && i2 == 0 && i3 == 0) {
                    fprintf(stderr, "DEBUG SSM_CONV: nc=%d, nr=%d, n_t=%d, n_s=%d\n", nc, nr, n_t, n_s);
                    fprintf(stderr, "DEBUG SSM_CONV: s[0..3]=%f,%f,%f,%f, c[0..3]=%f,%f,%f,%f, x[0]=%f\n",
                            s[0], s[1], s[2], s[3], c[0], c[1], c[2], c[3], x[0]);
                }
            }
        }
    }

@@ -8956,6 +8966,192 @@ void ggml_compute_forward_ssm_scan(
    }
}

// ggml_compute_forward_kda_scan
// KDA (Kimi Delta Attention) recurrence:
//   h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
//   o[t] = q[t]^T @ h[t]
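// where, per head: h is a [head_dim x head_dim] state matrix, g[t] is a
// per-channel log-decay (exp(g[i]) scales row i of h), beta[t] is a scalar
// per-head write strength, and (v[t] - h[t-1] @ k[t]) is the delta-rule
// prediction error written back along k[t].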

static void ggml_compute_forward_kda_scan_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // h    {head_dim, head_dim, n_head, n_seqs+}
    const ggml_tensor * src1 = dst->src[1]; // q    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src2 = dst->src[2]; // k    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src3 = dst->src[3]; // v    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src4 = dst->src[4]; // g    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src5 = dst->src[5]; // beta {n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src6 = dst->src[6]; // ids  {n_seqs}

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t head_dim     = src0->ne[0];
    const int64_t n_head       = src1->ne[1];
    const int64_t n_seq_tokens = src1->ne[2];
    const int64_t n_seqs       = src1->ne[3];

    GGML_ASSERT(head_dim <= 128); // q and k are staged in fixed 128-wide stack buffers below

    // Output offset for the hidden state
    const int64_t y_off = ggml_nelements(src1) * sizeof(float);
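    // dst layout: the y output (same shape as q) comes first, then the updated
    // per-sequence hidden states; y_off is the byte offset of the latter.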

    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    // Parallelize over heads
    const int dh  = (n_head + nth - 1) / nth;
    const int ih0 = dh * ith;
    const int ih1 = MIN(ih0 + dh, (int)n_head);

    const int32_t * ids = (const int32_t *) src6->data;

    // Temporary buffer for the h^T @ k computation
    float * hk_buf = (float *) malloc(head_dim * sizeof(float));

    static int debug_count = 0;
    bool do_debug = false; // (ith == 0 && debug_count++ < 20);

    for (int i3 = 0; i3 < n_seqs; ++i3) {
        // Initial hidden state for this sequence
        const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]);
        // Output hidden state location
        float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off);

        for (int ih = ih0; ih < ih1; ++ih) {
            // Per-head hidden state: [head_dim, head_dim]
            // Copy the initial state to the output (it is then updated in place)
            const float * h_in = h0    + ih * head_dim * head_dim;
            float       * h    = h_out + ih * head_dim * head_dim;

            // Copy the initial state, clearing it first if it contains invalid values
            bool need_clear = false;
            for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) {
                if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) {
                    need_clear = true;
                }
            }
            for (int i = 0; i < head_dim * head_dim; ++i) {
                h[i] = need_clear ? 0.0f : h_in[i];
            }

            for (int it = 0; it < n_seq_tokens; ++it) {
                const float * q_raw = (const float *) ((const char *) src1->data +
                        it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim;
                const float * k_raw = (const float *) ((const char *) src2->data +
                        it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim;
                const float * v = (const float *) ((const char *) src3->data +
                        it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim;
                const float * g = (const float *) ((const char *) src4->data +
                        it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim;
                const float beta = ((const float *) ((const char *) src5->data +
                        it * src5->nb[1] + i3 * src5->nb[2]))[ih];

                float * y = (float *) dst->data +
                        it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim;

                // L2 normalize q and k (critical for KDA stability)
                float q_norm = 0.0f, k_norm = 0.0f;
                for (int i = 0; i < head_dim; ++i) {
                    q_norm += q_raw[i] * q_raw[i];
                    k_norm += k_raw[i] * k_raw[i];
                }
                q_norm = sqrtf(q_norm + 1e-6f);
                k_norm = sqrtf(k_norm + 1e-6f);

                // Debug output
                if (do_debug && ih == 0 && it == 0 && i3 == 0) {
                    fprintf(stderr, "DEBUG KDA: q_raw[0]=%f, k_raw[0]=%f, v[0]=%f, g[0]=%f, beta=%f\n",
                            q_raw[0], k_raw[0], v[0], g[0], beta);
                    fprintf(stderr, "DEBUG KDA: q_norm=%f, k_norm=%f, exp(g[0])=%f, scale=%f\n",
                            q_norm, k_norm, expf(g[0]), 1.0f / sqrtf((float)head_dim));
                }

                // Normalize q and k, with scale = 1/sqrt(head_dim)
                // Note: the scale is applied only to q, after L2 normalization
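                // (scaling only q is equivalent to scaling the output, since y = h^T @ q is linear in q)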
                const float scale = 1.0f / sqrtf((float)head_dim);
                float q[128], k[128]; // head_dim <= 128, asserted above
                for (int i = 0; i < head_dim; ++i) {
                    // L2 normalize, then scale q
                    q[i] = (q_raw[i] / q_norm) * scale;
                    // L2 normalize k (no scale)
                    k[i] = k_raw[i] / k_norm;
                }

                // KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
                // Note: apply the decay first, then compute the retrieval and the update

                // Step 1: apply the decay to h: h = h * exp(g)
                for (int i = 0; i < head_dim; ++i) {
                    const float exp_gi = expf(g[i]);
                    for (int j = 0; j < head_dim; ++j) {
                        h[i * head_dim + j] *= exp_gi;
                    }
                }

                // Step 2: compute h^T @ k -> hk_buf [head_dim]
                // hk_buf[j] = sum_i (h[i,j] * k[i]), i.e. column j of h dotted with k
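                // (hk_buf is the state's current estimate of v for this k; Step 3 writes back the estimation error)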
                for (int j = 0; j < head_dim; ++j) {
                    float sum = 0.0f;
                    for (int i = 0; i < head_dim; ++i) {
                        sum += h[i * head_dim + j] * k[i];
                    }
                    hk_buf[j] = sum;
                }

                // Step 3: compute delta = beta * (v - hk) and update h
                // h = h + outer(k, delta), where outer(k, delta)[i,j] = k[i] * delta[j]
                for (int i = 0; i < head_dim; ++i) {
                    for (int j = 0; j < head_dim; ++j) {
                        const float delta_j = beta * (v[j] - hk_buf[j]);
                        h[i * head_dim + j] += k[i] * delta_j;
                    }
                }

                // Step 4: compute the output y = h^T @ q -> [head_dim]
                // vLLM: b_o = tl.sum(b_h * b_q[:, None], 0), i.e. o[j] = sum_i(h[i,j] * q[i])
                for (int j = 0; j < head_dim; ++j) {
                    float sum = 0.0f;
                    for (int i = 0; i < head_dim; ++i) {
                        sum += h[i * head_dim + j] * q[i];
                    }
                    y[j] = sum;
                }

                // Debug output
                if (do_debug && ih == 0 && it == 0 && i3 == 0) {
                    // Max abs value in h, as a stability check
                    float h_max = 0.0f;
                    for (int i = 0; i < head_dim * head_dim; i++) {
                        if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]);
                    }
                    fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n",
                            y[0], h_max, expf(g[0]));
                }
            }
        }
    }

    free(hk_buf);
}

void ggml_compute_forward_kda_scan(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    switch (dst->src[0]->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_kda_scan_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_win_part

static void ggml_compute_forward_win_part_f32(
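
For reference, a minimal standalone harness (not part of this commit; the values, the zero initial state, and the pre-normalized q/k are illustrative) that traces the same four-step update order as ggml_compute_forward_kda_scan_f32 for one head, one token, and head_dim = 2:

    #include <math.h>
    #include <stdio.h>

    #define HD 2 // head_dim for this example

    int main(void) {
        float h[HD*HD] = {0};               // per-head state, zero-initialized
        const float q[HD] = {1.0f, 0.0f};   // assumed already normalized and scaled
        const float k[HD] = {0.6f, 0.8f};   // assumed already L2-normalized (0.36 + 0.64 = 1)
        const float v[HD] = {0.5f, -0.25f};
        const float g[HD] = {-0.1f, -0.2f}; // log-decay, so exp(g) is in (0, 1)
        const float beta  = 1.0f;

        // Step 1: decay row i of h by exp(g[i])
        for (int i = 0; i < HD; ++i) {
            for (int j = 0; j < HD; ++j) {
                h[i*HD + j] *= expf(g[i]);
            }
        }

        // Step 2: retrieval hk[j] = sum_i h[i,j] * k[i]  (h^T @ k)
        float hk[HD];
        for (int j = 0; j < HD; ++j) {
            hk[j] = 0.0f;
            for (int i = 0; i < HD; ++i) {
                hk[j] += h[i*HD + j] * k[i];
            }
        }

        // Step 3: delta-rule update h += outer(k, beta * (v - hk))
        for (int i = 0; i < HD; ++i) {
            for (int j = 0; j < HD; ++j) {
                h[i*HD + j] += k[i] * beta * (v[j] - hk[j]);
            }
        }

        // Step 4: output y[j] = sum_i h[i,j] * q[i]  (h^T @ q)
        for (int j = 0; j < HD; ++j) {
            float y = 0.0f;
            for (int i = 0; i < HD; ++i) {
                y += h[i*HD + j] * q[i];
            }
            printf("y[%d] = %f\n", j, y);
        }
        return 0;
    }

Starting from h = 0, Steps 1 and 2 contribute nothing, Step 3 leaves h = outer(k, v), and Step 4 gives y[j] = (q . k) * v[j], so this prints y[0] = 0.300000 and y[1] = -0.150000.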
@@ -92,6 +92,7 @@ void ggml_compute_forward_flash_attn_back(
        struct ggml_tensor * dst);
void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);