rewrite get_vocab for KimiLinear. Removed all kda_scan code

This commit is contained in:
Yee Man Chan 2025-12-18 20:46:10 +08:00
parent ae9771d1dc
commit f9a11d7758
7 changed files with 0 additions and 471 deletions

View File

@ -1962,10 +1962,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_ssm_scan(params, tensor);
} break;
case GGML_OP_KDA_SCAN:
{
ggml_compute_forward_kda_scan(params, tensor);
} break;
case GGML_OP_WIN_PART:
{
ggml_compute_forward_win_part(params, tensor);

View File

@ -8686,7 +8686,6 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
static int conv_debug_count = 0;
bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3);
for (int i3 = 0; i3 < n_s; ++i3) {
@ -8966,192 +8965,6 @@ void ggml_compute_forward_ssm_scan(
}
}
// ggml_compute_forward_kda_scan
// KDA (Kimi Delta Attention) recurrence:
// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// o[t] = q[t]^T @ h[t]
static void ggml_compute_forward_kda_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // h {head_dim, head_dim, n_head, n_seqs+}
const ggml_tensor * src1 = dst->src[1]; // q {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src2 = dst->src[2]; // k {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src3 = dst->src[3]; // v {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src4 = dst->src[4]; // g {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src5 = dst->src[5]; // beta {n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
const int64_t head_dim = src0->ne[0];
const int64_t n_head = src1->ne[1];
const int64_t n_seq_tokens = src1->ne[2];
const int64_t n_seqs = src1->ne[3];
// Output offset for hidden state
const int64_t y_off = ggml_nelements(src1) * sizeof(float);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// Parallelize over heads
const int dh = (n_head + nth - 1) / nth;
const int ih0 = dh * ith;
const int ih1 = MIN(ih0 + dh, (int)n_head);
const int32_t * ids = (const int32_t *) src6->data;
// Temporary buffer for h @ k computation
float * hk_buf = (float *) malloc(head_dim * sizeof(float));
static int debug_count = 0;
bool do_debug = false; // (ith == 0 && debug_count++ < 20);
for (int i3 = 0; i3 < n_seqs; ++i3) {
// Get initial hidden state for this sequence
const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]);
// Output hidden state location
float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off);
for (int ih = ih0; ih < ih1; ++ih) {
// Per-head hidden state: [head_dim, head_dim]
// Copy initial state to output (will be updated in place)
const float * h_in = h0 + ih * head_dim * head_dim;
float * h = h_out + ih * head_dim * head_dim;
// Copy initial state, but check for invalid values and clear if needed
bool need_clear = false;
for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) {
if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) {
need_clear = true;
}
}
for (int i = 0; i < head_dim * head_dim; ++i) {
h[i] = need_clear ? 0.0f : h_in[i];
}
for (int it = 0; it < n_seq_tokens; ++it) {
const float * q_raw = (const float *) ((const char *) src1->data +
it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim;
const float * k_raw = (const float *) ((const char *) src2->data +
it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim;
const float * v = (const float *) ((const char *) src3->data +
it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim;
const float * g = (const float *) ((const char *) src4->data +
it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim;
const float beta = ((const float *) ((const char *) src5->data +
it * src5->nb[1] + i3 * src5->nb[2]))[ih];
float * y = (float *) dst->data +
it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim;
// L2 normalize q and k (critical for KDA stability)
float q_norm = 0.0f, k_norm = 0.0f;
for (int i = 0; i < head_dim; ++i) {
q_norm += q_raw[i] * q_raw[i];
k_norm += k_raw[i] * k_raw[i];
}
q_norm = sqrtf(q_norm + 1e-6f);
k_norm = sqrtf(k_norm + 1e-6f);
// Debug output
if (do_debug && ih == 0 && it == 0 && i3 == 0) {
fprintf(stderr, "DEBUG KDA: q_raw[0]=%f, k_raw[0]=%f, v[0]=%f, g[0]=%f, beta=%f\n",
q_raw[0], k_raw[0], v[0], g[0], beta);
fprintf(stderr, "DEBUG KDA: q_norm=%f, k_norm=%f, exp(g[0])=%f, scale=%f\n",
q_norm, k_norm, expf(g[0]), 1.0f / sqrtf((float)head_dim));
}
// Normalized q and k with scale = 1/sqrt(head_dim)
// Note: scale is applied only to q after L2 normalization
const float scale = 1.0f / sqrtf((float)head_dim);
float q[128], k[128]; // assume head_dim <= 128
for (int i = 0; i < head_dim; ++i) {
// L2 normalize then scale q
q[i] = (q_raw[i] / q_norm) * scale;
// L2 normalize k (no scale)
k[i] = k_raw[i] / k_norm;
}
// KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// Note: Apply decay first, then compute retrieval and update
// Step 1: Apply decay to h first: h = h * exp(g)
for (int i = 0; i < head_dim; ++i) {
const float exp_gi = expf(g[i]);
for (int j = 0; j < head_dim; ++j) {
h[i * head_dim + j] *= exp_gi;
}
}
// Step 2: Compute h^T @ k -> hk_buf [head_dim]
// hk_buf[j] = sum_i (h[i,j] * k[i]) which is column j of h dotted with k
for (int j = 0; j < head_dim; ++j) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h[i * head_dim + j] * k[i];
}
hk_buf[j] = sum;
}
// Step 3: Compute delta = beta * (v - hk) and update h
// h = h + outer(k, delta) where outer(k,delta)[i,j] = k[i] * delta[j]
for (int i = 0; i < head_dim; ++i) {
for (int j = 0; j < head_dim; ++j) {
const float delta_j = beta * (v[j] - hk_buf[j]);
h[i * head_dim + j] += k[i] * delta_j;
}
}
// Step 4: Compute output y = h^T @ q -> [head_dim]
// vLLM: b_o = tl.sum(b_h * b_q[:, None], 0) means o[j] = sum_i(h[i,j] * q[i])
for (int j = 0; j < head_dim; ++j) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h[i * head_dim + j] * q[i];
}
y[j] = sum;
}
// Debug output
if (do_debug && ih == 0 && it == 0 && i3 == 0) {
// Find max abs value in h for stability check
float h_max = 0.0f;
for (int i = 0; i < head_dim * head_dim; i++) {
if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]);
}
fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n",
y[0], h_max, expf(g[0]));
}
}
}
}
free(hk_buf);
}
void ggml_compute_forward_kda_scan(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_kda_scan_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_win_part
static void ggml_compute_forward_win_part_f32(

View File

@ -92,7 +92,6 @@ void ggml_compute_forward_flash_attn_back(
struct ggml_tensor * dst);
void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -41,7 +41,6 @@
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/kda-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
@ -2693,9 +2692,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_KDA_SCAN:
ggml_cuda_op_kda_scan(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;

View File

@ -1,209 +0,0 @@
#include "kda-scan.cuh"
// KDA (Kimi Delta Attention) scan CUDA kernel
// Recurrence:
// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// o[t] = q[t]^T @ h[t]
//
// This kernel uses global memory for the hidden state to avoid shared memory limits.
// Each block processes one head for one sequence.
__global__ void kda_scan_f32_kernel(
const float * __restrict__ src0, // h: [head_dim, head_dim, n_head, n_seqs+]
const float * __restrict__ src1, // q: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src2, // k: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src3, // v: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src4, // g: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src5, // beta: [n_head, n_seq_tokens, n_seqs]
const int32_t * __restrict__ src6, // ids: [n_seqs]
float * __restrict__ dst,
const int64_t head_dim,
const int64_t n_head,
const int64_t n_seq_tokens,
const int64_t n_seqs,
const int64_t y_off) // offset to state output in dst (in floats)
{
// Each block handles one head for one sequence
const int seq_idx = blockIdx.x / n_head;
const int head_idx = blockIdx.x % n_head;
const int tid = threadIdx.x;
const int n_threads = blockDim.x;
if (seq_idx >= n_seqs || head_idx >= n_head) return;
// Get sequence ID for initial state
const int src_seq = src6[seq_idx];
// Shared memory for temporary buffers
extern __shared__ float smem[];
float * hk_buf = smem; // [head_dim] - h @ k buffer
float * q_norm = smem + head_dim; // [head_dim] - normalized q
float * k_norm = q_norm + head_dim; // [head_dim] - normalized k
float * warp_sums = k_norm + head_dim; // [64] - for reductions
// Pointers to input/output data for this head
const int64_t h_stride_head = head_dim * head_dim;
const int64_t h_stride_seq = h_stride_head * n_head;
const int64_t qkv_stride_head = head_dim;
const int64_t qkv_stride_token = head_dim * n_head;
const int64_t qkv_stride_seq = qkv_stride_token * n_seq_tokens;
const int64_t beta_stride_token = n_head;
const int64_t beta_stride_seq = beta_stride_token * n_seq_tokens;
const float * h_in = src0 + src_seq * h_stride_seq + head_idx * h_stride_head;
float * h_out = dst + y_off + seq_idx * h_stride_seq + head_idx * h_stride_head;
float * y_out = dst + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
// Copy initial state to output (we'll update in place)
for (int i = tid; i < head_dim * head_dim; i += n_threads) {
float val = h_in[i];
if (!isfinite(val) || fabsf(val) > 1e6f) {
val = 0.0f;
}
h_out[i] = val;
}
__syncthreads();
const float scale = 1.0f / sqrtf((float)head_dim);
// Process each token sequentially
for (int t = 0; t < n_seq_tokens; ++t) {
const float * q_raw = src1 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float * k_raw = src2 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float * v = src3 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float * g = src4 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float beta = src5[t * beta_stride_token + seq_idx * beta_stride_seq + head_idx];
float * y = y_out + t * qkv_stride_token;
// Step 1: L2 normalize q and k
float q_sq_sum = 0.0f, k_sq_sum = 0.0f;
for (int i = tid; i < head_dim; i += n_threads) {
q_sq_sum += q_raw[i] * q_raw[i];
k_sq_sum += k_raw[i] * k_raw[i];
}
// Warp reduction
for (int offset = warpSize/2; offset > 0; offset /= 2) {
q_sq_sum += __shfl_down_sync(0xffffffff, q_sq_sum, offset);
k_sq_sum += __shfl_down_sync(0xffffffff, k_sq_sum, offset);
}
// Cross-warp reduction
int warp_id = tid / warpSize;
int lane_id = tid % warpSize;
if (lane_id == 0 && warp_id < 32) {
warp_sums[warp_id] = q_sq_sum;
warp_sums[32 + warp_id] = k_sq_sum;
}
__syncthreads();
if (tid == 0) {
float total_q = 0.0f, total_k = 0.0f;
for (int i = 0; i < (n_threads + warpSize - 1) / warpSize; ++i) {
total_q += warp_sums[i];
total_k += warp_sums[32 + i];
}
warp_sums[0] = rsqrtf(total_q + 1e-6f) * scale;
warp_sums[1] = rsqrtf(total_k + 1e-6f);
}
__syncthreads();
float q_norm_factor = warp_sums[0];
float k_norm_factor = warp_sums[1];
// Store normalized q and k
for (int i = tid; i < head_dim; i += n_threads) {
q_norm[i] = q_raw[i] * q_norm_factor;
k_norm[i] = k_raw[i] * k_norm_factor;
}
__syncthreads();
// KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// Apply decay first, then compute retrieval and update
// Step 2: Apply decay to h: h = h * exp(g)
for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) {
int i = idx / head_dim;
float exp_gi = expf(g[i]);
h_out[idx] *= exp_gi;
}
__syncthreads();
// Step 3: Compute h^T @ k -> hk_buf
for (int j = tid; j < head_dim; j += n_threads) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h_out[i * head_dim + j] * k_norm[i];
}
hk_buf[j] = sum;
}
__syncthreads();
// Step 4: Update h: h = h + outer(k, beta * (v - hk))
for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) {
int i = idx / head_dim;
int j = idx % head_dim;
float delta_j = beta * (v[j] - hk_buf[j]);
h_out[idx] += k_norm[i] * delta_j;
}
__syncthreads();
// Step 5: Compute output y = h^T @ q
for (int j = tid; j < head_dim; j += n_threads) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h_out[i * head_dim + j] * q_norm[i];
}
y[j] = sum;
}
__syncthreads();
}
}
void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // h
const ggml_tensor * src1 = dst->src[1]; // q
const ggml_tensor * src2 = dst->src[2]; // k
const ggml_tensor * src3 = dst->src[3]; // v
const ggml_tensor * src4 = dst->src[4]; // g
const ggml_tensor * src5 = dst->src[5]; // beta
const ggml_tensor * src6 = dst->src[6]; // ids
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src3->type == GGML_TYPE_F32);
GGML_ASSERT(src4->type == GGML_TYPE_F32);
GGML_ASSERT(src5->type == GGML_TYPE_F32);
GGML_ASSERT(src6->type == GGML_TYPE_I32);
const int64_t head_dim = src0->ne[0];
const int64_t n_head = src1->ne[1];
const int64_t n_seq_tokens = src1->ne[2];
const int64_t n_seqs = src1->ne[3];
// Output offset for hidden state (after attention output) - in floats
const int64_t y_off = ggml_nelements(src1);
const float * h_d = (const float *)src0->data;
const float * q_d = (const float *)src1->data;
const float * k_d = (const float *)src2->data;
const float * v_d = (const float *)src3->data;
const float * g_d = (const float *)src4->data;
const float * beta_d = (const float *)src5->data;
const int32_t * ids_d = (const int32_t *)src6->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
// Launch kernel: one block per (sequence, head) pair
const int n_blocks = n_seqs * n_head;
const int n_threads = 128;
// Shared memory: hk_buf[head_dim] + q_norm[head_dim] + k_norm[head_dim] + warp_sums[64]
size_t smem_size = (3 * head_dim + 64) * sizeof(float);
kda_scan_f32_kernel<<<n_blocks, n_threads, smem_size, stream>>>(
h_d, q_d, k_d, v_d, g_d, beta_d, ids_d, dst_d,
head_dim, n_head, n_seq_tokens, n_seqs, y_off);
}

View File

@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -5435,69 +5435,6 @@ struct ggml_tensor * ggml_ssm_scan(
return result;
}
// ggml_kda_scan
struct ggml_tensor * ggml_kda_scan(
struct ggml_context * ctx,
struct ggml_tensor * h,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * ids) {
GGML_ASSERT(ggml_is_contiguous(h));
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
{
const int64_t head_dim = h->ne[0];
const int64_t n_head = q->ne[1];
const int64_t n_seq_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
GGML_ASSERT(h->ne[0] == head_dim);
GGML_ASSERT(h->ne[1] == head_dim);
GGML_ASSERT(h->ne[2] == n_head);
GGML_ASSERT(q->ne[0] == head_dim);
GGML_ASSERT(k->ne[0] == head_dim);
GGML_ASSERT(v->ne[0] == head_dim);
GGML_ASSERT(g->ne[0] == head_dim);
GGML_ASSERT(ggml_are_same_shape(q, k));
GGML_ASSERT(ggml_are_same_shape(q, v));
GGML_ASSERT(ggml_are_same_shape(q, g));
GGML_ASSERT(beta->ne[0] == n_head);
GGML_ASSERT(beta->ne[1] == n_seq_tokens);
GGML_ASSERT(beta->ne[2] == n_seqs);
GGML_ASSERT(ids->ne[0] == n_seqs);
GGML_ASSERT(ggml_is_vector(ids));
}
// Output: y (attention output) + updated hidden states
// y: {head_dim, n_head, n_seq_tokens, n_seqs}
// h_new: {head_dim, head_dim, n_head, n_seqs}
const int64_t head_dim = h->ne[0];
const int64_t n_head = q->ne[1];
const int64_t n_seqs = q->ne[3];
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
ggml_nelements(q) + head_dim * head_dim * n_head * n_seqs);
result->op = GGML_OP_KDA_SCAN;
result->src[0] = h;
result->src[1] = q;
result->src[2] = k;
result->src[3] = v;
result->src[4] = g;
result->src[5] = beta;
result->src[6] = ids;
return result;
}
// ggml_win_part
struct ggml_tensor * ggml_win_part(