Rewrite get_vocab for KimiLinear; remove all kda_scan code
This commit is contained in:
parent ae9771d1dc
commit f9a11d7758
@@ -1962,10 +1962,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
            {
                ggml_compute_forward_ssm_scan(params, tensor);
            } break;
        case GGML_OP_KDA_SCAN:
            {
                ggml_compute_forward_kda_scan(params, tensor);
            } break;
        case GGML_OP_WIN_PART:
            {
                ggml_compute_forward_win_part(params, tensor);
@@ -8686,7 +8686,6 @@ static void ggml_compute_forward_ssm_conv_f32(
    const int ir1 = MIN(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    static int conv_debug_count = 0;
    bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3);

    for (int i3 = 0; i3 < n_s; ++i3) {
@@ -8966,192 +8965,6 @@ void ggml_compute_forward_ssm_scan(
    }
}

// ggml_compute_forward_kda_scan
// KDA (Kimi Delta Attention) recurrence:
//   h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
//   o[t] = q[t]^T @ h[t]

static void ggml_compute_forward_kda_scan_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // h    {head_dim, head_dim, n_head, n_seqs+}
    const ggml_tensor * src1 = dst->src[1]; // q    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src2 = dst->src[2]; // k    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src3 = dst->src[3]; // v    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src4 = dst->src[4]; // g    {head_dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src5 = dst->src[5]; // beta {n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src6 = dst->src[6]; // ids  {n_seqs}

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t head_dim     = src0->ne[0];
    const int64_t n_head       = src1->ne[1];
    const int64_t n_seq_tokens = src1->ne[2];
    const int64_t n_seqs       = src1->ne[3];

    // Output offset for hidden state
    const int64_t y_off = ggml_nelements(src1) * sizeof(float);

    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    // Parallelize over heads
    const int dh  = (n_head + nth - 1) / nth;
    const int ih0 = dh * ith;
    const int ih1 = MIN(ih0 + dh, (int)n_head);

    const int32_t * ids = (const int32_t *) src6->data;

    // Temporary buffer for h @ k computation
    float * hk_buf = (float *) malloc(head_dim * sizeof(float));

    static int debug_count = 0;
    bool do_debug = false; // (ith == 0 && debug_count++ < 20);

    for (int i3 = 0; i3 < n_seqs; ++i3) {
        // Get initial hidden state for this sequence
        const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]);
        // Output hidden state location
        float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off);

        for (int ih = ih0; ih < ih1; ++ih) {
            // Per-head hidden state: [head_dim, head_dim]
            // Copy initial state to output (will be updated in place)
            const float * h_in = h0 + ih * head_dim * head_dim;
            float * h = h_out + ih * head_dim * head_dim;

            // Copy initial state, but check for invalid values and clear if needed
            bool need_clear = false;
            for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) {
                if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) {
                    need_clear = true;
                }
            }
            for (int i = 0; i < head_dim * head_dim; ++i) {
                h[i] = need_clear ? 0.0f : h_in[i];
            }

            for (int it = 0; it < n_seq_tokens; ++it) {
                const float * q_raw = (const float *) ((const char *) src1->data +
                    it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim;
                const float * k_raw = (const float *) ((const char *) src2->data +
                    it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim;
                const float * v = (const float *) ((const char *) src3->data +
                    it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim;
                const float * g = (const float *) ((const char *) src4->data +
                    it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim;
                const float beta = ((const float *) ((const char *) src5->data +
                    it * src5->nb[1] + i3 * src5->nb[2]))[ih];

                float * y = (float *) dst->data +
                    it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim;

                // L2 normalize q and k (critical for KDA stability)
                float q_norm = 0.0f, k_norm = 0.0f;
                for (int i = 0; i < head_dim; ++i) {
                    q_norm += q_raw[i] * q_raw[i];
                    k_norm += k_raw[i] * k_raw[i];
                }
                q_norm = sqrtf(q_norm + 1e-6f);
                k_norm = sqrtf(k_norm + 1e-6f);

                // Debug output
                if (do_debug && ih == 0 && it == 0 && i3 == 0) {
                    fprintf(stderr, "DEBUG KDA: q_raw[0]=%f, k_raw[0]=%f, v[0]=%f, g[0]=%f, beta=%f\n",
                            q_raw[0], k_raw[0], v[0], g[0], beta);
                    fprintf(stderr, "DEBUG KDA: q_norm=%f, k_norm=%f, exp(g[0])=%f, scale=%f\n",
                            q_norm, k_norm, expf(g[0]), 1.0f / sqrtf((float)head_dim));
                }

                // Normalized q and k with scale = 1/sqrt(head_dim)
                // Note: scale is applied only to q after L2 normalization
                const float scale = 1.0f / sqrtf((float)head_dim);
                float q[128], k[128]; // assume head_dim <= 128
                for (int i = 0; i < head_dim; ++i) {
                    // L2 normalize then scale q
                    q[i] = (q_raw[i] / q_norm) * scale;
                    // L2 normalize k (no scale)
                    k[i] = k_raw[i] / k_norm;
                }

                // KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
                // Note: Apply decay first, then compute retrieval and update

                // Step 1: Apply decay to h first: h = h * exp(g)
                for (int i = 0; i < head_dim; ++i) {
                    const float exp_gi = expf(g[i]);
                    for (int j = 0; j < head_dim; ++j) {
                        h[i * head_dim + j] *= exp_gi;
                    }
                }

                // Step 2: Compute h^T @ k -> hk_buf [head_dim]
                // hk_buf[j] = sum_i (h[i,j] * k[i]), which is column j of h dotted with k
                for (int j = 0; j < head_dim; ++j) {
                    float sum = 0.0f;
                    for (int i = 0; i < head_dim; ++i) {
                        sum += h[i * head_dim + j] * k[i];
                    }
                    hk_buf[j] = sum;
                }

                // Step 3: Compute delta = beta * (v - hk) and update h
                // h = h + outer(k, delta), where outer(k, delta)[i,j] = k[i] * delta[j]
                for (int i = 0; i < head_dim; ++i) {
                    for (int j = 0; j < head_dim; ++j) {
                        const float delta_j = beta * (v[j] - hk_buf[j]);
                        h[i * head_dim + j] += k[i] * delta_j;
                    }
                }

                // Step 4: Compute output y = h^T @ q -> [head_dim]
                // vLLM: b_o = tl.sum(b_h * b_q[:, None], 0) means o[j] = sum_i(h[i,j] * q[i])
                for (int j = 0; j < head_dim; ++j) {
                    float sum = 0.0f;
                    for (int i = 0; i < head_dim; ++i) {
                        sum += h[i * head_dim + j] * q[i];
                    }
                    y[j] = sum;
                }

                // Debug output
                if (do_debug && ih == 0 && it == 0 && i3 == 0) {
                    // Find max abs value in h for stability check
                    float h_max = 0.0f;
                    for (int i = 0; i < head_dim * head_dim; i++) {
                        if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]);
                    }
                    fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n",
                            y[0], h_max, expf(g[0]));
                }
            }
        }
    }

    free(hk_buf);
}

void ggml_compute_forward_kda_scan(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    switch (dst->src[0]->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_kda_scan_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_win_part

static void ggml_compute_forward_win_part_f32(
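A note on the removed loops above: the decay is applied before the retrieval term, so the state queried by k[t] is the already-decayed one. With row index i over the key dimension and column j over the value dimension (h is head_dim x head_dim per head), the step the code computes is, in matrix form:

    \tilde{h} = \mathrm{diag}\!\left(e^{g_t}\right) h_{t-1}, \qquad
    h_t = \tilde{h} + k_t \,\beta_t \left(v_t - \tilde{h}^{\top} k_t\right)^{\top}, \qquad
    o_t = h_t^{\top} q_t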
@@ -92,7 +92,6 @@ void ggml_compute_forward_flash_attn_back(
        struct ggml_tensor * dst);
void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -41,7 +41,6 @@
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/kda-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
@@ -2693,9 +2692,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
        case GGML_OP_SSM_SCAN:
            ggml_cuda_op_ssm_scan(ctx, dst);
            break;
        case GGML_OP_KDA_SCAN:
            ggml_cuda_op_kda_scan(ctx, dst);
            break;
        case GGML_OP_ARGSORT:
            ggml_cuda_op_argsort(ctx, dst);
            break;
@@ -1,209 +0,0 @@
#include "kda-scan.cuh"

// KDA (Kimi Delta Attention) scan CUDA kernel
// Recurrence:
//   h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
//   o[t] = q[t]^T @ h[t]
//
// This kernel uses global memory for the hidden state to avoid shared memory limits.
// Each block processes one head for one sequence.

__global__ void kda_scan_f32_kernel(
        const float   * __restrict__ src0, // h:    [head_dim, head_dim, n_head, n_seqs+]
        const float   * __restrict__ src1, // q:    [head_dim, n_head, n_seq_tokens, n_seqs]
        const float   * __restrict__ src2, // k:    [head_dim, n_head, n_seq_tokens, n_seqs]
        const float   * __restrict__ src3, // v:    [head_dim, n_head, n_seq_tokens, n_seqs]
        const float   * __restrict__ src4, // g:    [head_dim, n_head, n_seq_tokens, n_seqs]
        const float   * __restrict__ src5, // beta: [n_head, n_seq_tokens, n_seqs]
        const int32_t * __restrict__ src6, // ids:  [n_seqs]
        float * __restrict__ dst,
        const int64_t head_dim,
        const int64_t n_head,
        const int64_t n_seq_tokens,
        const int64_t n_seqs,
        const int64_t y_off) // offset to state output in dst (in floats)
{
    // Each block handles one head for one sequence
    const int seq_idx   = blockIdx.x / n_head;
    const int head_idx  = blockIdx.x % n_head;
    const int tid       = threadIdx.x;
    const int n_threads = blockDim.x;

    if (seq_idx >= n_seqs || head_idx >= n_head) return;

    // Get sequence ID for initial state
    const int src_seq = src6[seq_idx];

    // Shared memory for temporary buffers
    extern __shared__ float smem[];
    float * hk_buf    = smem;              // [head_dim] - h @ k buffer
    float * q_norm    = smem + head_dim;   // [head_dim] - normalized q
    float * k_norm    = q_norm + head_dim; // [head_dim] - normalized k
    float * warp_sums = k_norm + head_dim; // [64] - for reductions

    // Pointers to input/output data for this head
    const int64_t h_stride_head     = head_dim * head_dim;
    const int64_t h_stride_seq      = h_stride_head * n_head;
    const int64_t qkv_stride_head   = head_dim;
    const int64_t qkv_stride_token  = head_dim * n_head;
    const int64_t qkv_stride_seq    = qkv_stride_token * n_seq_tokens;
    const int64_t beta_stride_token = n_head;
    const int64_t beta_stride_seq   = beta_stride_token * n_seq_tokens;

    const float * h_in  = src0 + src_seq * h_stride_seq + head_idx * h_stride_head;
    float       * h_out = dst + y_off + seq_idx * h_stride_seq + head_idx * h_stride_head;
    float       * y_out = dst + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;

    // Copy initial state to output (we'll update in place)
    for (int i = tid; i < head_dim * head_dim; i += n_threads) {
        float val = h_in[i];
        if (!isfinite(val) || fabsf(val) > 1e6f) {
            val = 0.0f;
        }
        h_out[i] = val;
    }
    __syncthreads();

    const float scale = 1.0f / sqrtf((float)head_dim);

    // Process each token sequentially
    for (int t = 0; t < n_seq_tokens; ++t) {
        const float * q_raw = src1 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
        const float * k_raw = src2 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
        const float * v     = src3 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
        const float * g     = src4 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
        const float   beta  = src5[t * beta_stride_token + seq_idx * beta_stride_seq + head_idx];
        float * y = y_out + t * qkv_stride_token;

        // Step 1: L2 normalize q and k
        float q_sq_sum = 0.0f, k_sq_sum = 0.0f;
        for (int i = tid; i < head_dim; i += n_threads) {
            q_sq_sum += q_raw[i] * q_raw[i];
            k_sq_sum += k_raw[i] * k_raw[i];
        }

        // Warp reduction
        for (int offset = warpSize/2; offset > 0; offset /= 2) {
            q_sq_sum += __shfl_down_sync(0xffffffff, q_sq_sum, offset);
            k_sq_sum += __shfl_down_sync(0xffffffff, k_sq_sum, offset);
        }

        // Cross-warp reduction
        int warp_id = tid / warpSize;
        int lane_id = tid % warpSize;
        if (lane_id == 0 && warp_id < 32) {
            warp_sums[warp_id]      = q_sq_sum;
            warp_sums[32 + warp_id] = k_sq_sum;
        }
        __syncthreads();

        if (tid == 0) {
            float total_q = 0.0f, total_k = 0.0f;
            for (int i = 0; i < (n_threads + warpSize - 1) / warpSize; ++i) {
                total_q += warp_sums[i];
                total_k += warp_sums[32 + i];
            }
            warp_sums[0] = rsqrtf(total_q + 1e-6f) * scale;
            warp_sums[1] = rsqrtf(total_k + 1e-6f);
        }
        __syncthreads();

        float q_norm_factor = warp_sums[0];
        float k_norm_factor = warp_sums[1];

        // Store normalized q and k
        for (int i = tid; i < head_dim; i += n_threads) {
            q_norm[i] = q_raw[i] * q_norm_factor;
            k_norm[i] = k_raw[i] * k_norm_factor;
        }
        __syncthreads();

        // KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
        // Apply decay first, then compute retrieval and update

        // Step 2: Apply decay to h: h = h * exp(g)
        for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) {
            int i = idx / head_dim;
            float exp_gi = expf(g[i]);
            h_out[idx] *= exp_gi;
        }
        __syncthreads();

        // Step 3: Compute h^T @ k -> hk_buf
        for (int j = tid; j < head_dim; j += n_threads) {
            float sum = 0.0f;
            for (int i = 0; i < head_dim; ++i) {
                sum += h_out[i * head_dim + j] * k_norm[i];
            }
            hk_buf[j] = sum;
        }
        __syncthreads();

        // Step 4: Update h: h = h + outer(k, beta * (v - hk))
        for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) {
            int i = idx / head_dim;
            int j = idx % head_dim;
            float delta_j = beta * (v[j] - hk_buf[j]);
            h_out[idx] += k_norm[i] * delta_j;
        }
        __syncthreads();

        // Step 5: Compute output y = h^T @ q
        for (int j = tid; j < head_dim; j += n_threads) {
            float sum = 0.0f;
            for (int i = 0; i < head_dim; ++i) {
                sum += h_out[i * head_dim + j] * q_norm[i];
            }
            y[j] = sum;
        }
        __syncthreads();
    }
}

void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // h
    const ggml_tensor * src1 = dst->src[1]; // q
    const ggml_tensor * src2 = dst->src[2]; // k
    const ggml_tensor * src3 = dst->src[3]; // v
    const ggml_tensor * src4 = dst->src[4]; // g
    const ggml_tensor * src5 = dst->src[5]; // beta
    const ggml_tensor * src6 = dst->src[6]; // ids

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_F32);
    GGML_ASSERT(src3->type == GGML_TYPE_F32);
    GGML_ASSERT(src4->type == GGML_TYPE_F32);
    GGML_ASSERT(src5->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);

    const int64_t head_dim     = src0->ne[0];
    const int64_t n_head       = src1->ne[1];
    const int64_t n_seq_tokens = src1->ne[2];
    const int64_t n_seqs       = src1->ne[3];

    // Output offset for hidden state (after attention output) - in floats
    const int64_t y_off = ggml_nelements(src1);

    const float   * h_d    = (const float   *) src0->data;
    const float   * q_d    = (const float   *) src1->data;
    const float   * k_d    = (const float   *) src2->data;
    const float   * v_d    = (const float   *) src3->data;
    const float   * g_d    = (const float   *) src4->data;
    const float   * beta_d = (const float   *) src5->data;
    const int32_t * ids_d  = (const int32_t *) src6->data;
    float * dst_d = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    // Launch kernel: one block per (sequence, head) pair
    const int n_blocks  = n_seqs * n_head;
    const int n_threads = 128;

    // Shared memory: hk_buf[head_dim] + q_norm[head_dim] + k_norm[head_dim] + warp_sums[64]
    size_t smem_size = (3 * head_dim + 64) * sizeof(float);

    kda_scan_f32_kernel<<<n_blocks, n_threads, smem_size, stream>>>(
        h_d, q_d, k_d, v_d, g_d, beta_d, ids_d, dst_d,
        head_dim, n_head, n_seq_tokens, n_seqs, y_off);
}
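A quick sanity check of the launch configuration in ggml_cuda_op_kda_scan: one block per (sequence, head) pair, 128 threads, and (3 * head_dim + 64) floats of dynamic shared memory. The concrete sizes below (head_dim = 128, n_head = 16, n_seqs = 4) are hypothetical, chosen only to illustrate the arithmetic:

// Hypothetical sizes for illustration; not taken from this commit.
#include <cstdio>

int main() {
    const long head_dim = 128, n_head = 16, n_seqs = 4;

    const long n_blocks  = n_seqs * n_head;                      // 64 blocks, one per (seq, head)
    const long smem_size = (3 * head_dim + 64) * sizeof(float);  // 1792 bytes

    // 1792 bytes is far below the 48 KiB default dynamic shared memory limit,
    // so no opt-in via cudaFuncSetAttribute is needed for this kernel.
    printf("blocks=%ld threads=128 smem=%ld bytes\n", n_blocks, smem_size);
    return 0;
}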
@@ -1,3 +0,0 @@
#include "common.cuh"

void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -5435,69 +5435,6 @@ struct ggml_tensor * ggml_ssm_scan(
    return result;
}

// ggml_kda_scan

struct ggml_tensor * ggml_kda_scan(
        struct ggml_context * ctx,
        struct ggml_tensor  * h,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * g,
        struct ggml_tensor  * beta,
        struct ggml_tensor  * ids) {
    GGML_ASSERT(ggml_is_contiguous(h));
    GGML_ASSERT(ggml_is_contiguous(q));
    GGML_ASSERT(ggml_is_contiguous(k));
    GGML_ASSERT(ggml_is_contiguous(v));
    GGML_ASSERT(ggml_is_contiguous(g));
    GGML_ASSERT(ggml_is_contiguous(beta));
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    {
        const int64_t head_dim     = h->ne[0];
        const int64_t n_head       = q->ne[1];
        const int64_t n_seq_tokens = q->ne[2];
        const int64_t n_seqs       = q->ne[3];

        GGML_ASSERT(h->ne[0] == head_dim);
        GGML_ASSERT(h->ne[1] == head_dim);
        GGML_ASSERT(h->ne[2] == n_head);
        GGML_ASSERT(q->ne[0] == head_dim);
        GGML_ASSERT(k->ne[0] == head_dim);
        GGML_ASSERT(v->ne[0] == head_dim);
        GGML_ASSERT(g->ne[0] == head_dim);
        GGML_ASSERT(ggml_are_same_shape(q, k));
        GGML_ASSERT(ggml_are_same_shape(q, v));
        GGML_ASSERT(ggml_are_same_shape(q, g));
        GGML_ASSERT(beta->ne[0] == n_head);
        GGML_ASSERT(beta->ne[1] == n_seq_tokens);
        GGML_ASSERT(beta->ne[2] == n_seqs);
        GGML_ASSERT(ids->ne[0] == n_seqs);
        GGML_ASSERT(ggml_is_vector(ids));
    }

    // Output: y (attention output) + updated hidden states
    //   y:     {head_dim, n_head, n_seq_tokens, n_seqs}
    //   h_new: {head_dim, head_dim, n_head, n_seqs}
    const int64_t head_dim = h->ne[0];
    const int64_t n_head   = q->ne[1];
    const int64_t n_seqs   = q->ne[3];
    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
        ggml_nelements(q) + head_dim * head_dim * n_head * n_seqs);

    result->op     = GGML_OP_KDA_SCAN;
    result->src[0] = h;
    result->src[1] = q;
    result->src[2] = k;
    result->src[3] = v;
    result->src[4] = g;
    result->src[5] = beta;
    result->src[6] = ids;

    return result;
}

// ggml_win_part

struct ggml_tensor * ggml_win_part(
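Since ggml_kda_scan packs the attention output y and the updated per-sequence states into one flat F32 tensor, a caller has to slice the result apart. A minimal caller-side sketch using the standard ggml_view_1d API; the snippet assumes h, q, k, v, g, beta, ids, and the dimension variables are in scope, and is illustrative rather than part of this commit:

// Split the packed result: y comes first (ggml_nelements(q) floats),
// followed by h_new (head_dim * head_dim * n_head * n_seqs floats).
struct ggml_tensor * out = ggml_kda_scan(ctx, h, q, k, v, g, beta, ids);

const int64_t n_y = ggml_nelements(q);
const int64_t n_h = head_dim * head_dim * n_head * n_seqs;

struct ggml_tensor * y     = ggml_view_1d(ctx, out, n_y, 0);
struct ggml_tensor * h_new = ggml_view_1d(ctx, out, n_h, n_y * ggml_element_size(out));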