Kimi Linear ggml.c

This commit is contained in:
Yee Man Chan 2025-12-02 11:27:57 +08:00
parent bf42bc0606
commit d73d3e51a5
1 changed file with 66 additions and 1 deletion

View File

@ -999,6 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"FLASH_ATTN_BACK",
"SSM_CONV",
"SSM_SCAN",
"KDA_SCAN",
"WIN_PART",
"WIN_UNPART",
"GET_REL_POS",
@ -1024,7 +1025,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -5434,6 +5435,70 @@ struct ggml_tensor * ggml_ssm_scan(
return result;
}
// ggml_kda_scan
//
// Builds a KDA (presumably "Kimi Delta Attention" — confirm against the model
// code) scan op. Inputs:
//   h    : recurrent state, {head_dim, head_dim, n_head, >= n_seqs}
//   q/k/v: per-token projections, {head_dim, n_head, n_seq_tokens, n_seqs}
//   g    : per-channel gate/decay, same shape as q
//   beta : per-head scalar gate, {n_head, n_seq_tokens, n_seqs}
//   ids  : I32 vector of n_seqs state slot indices into h
// The result is a single packed F32 1-D tensor holding
//   y     {head_dim, n_head, n_seq_tokens, n_seqs}  (ggml_nelements(q) values)
// followed by
//   h_new {head_dim, head_dim, n_head, n_seqs}
// mirroring the packed-output convention of ggml_ssm_scan above.
struct ggml_tensor * ggml_kda_scan(
        struct ggml_context * ctx,
        struct ggml_tensor  * h,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * g,
        struct ggml_tensor  * beta,
        struct ggml_tensor  * ids) {
    GGML_ASSERT(ggml_is_contiguous(h));
    GGML_ASSERT(ggml_is_contiguous(q));
    GGML_ASSERT(ggml_is_contiguous(k));
    GGML_ASSERT(ggml_is_contiguous(v));
    GGML_ASSERT(ggml_is_contiguous(g));
    GGML_ASSERT(ggml_is_contiguous(beta));
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    // derive the dimensions once (previously computed twice: in an inner
    // scope for the checks and again afterwards for the result size)
    const int64_t head_dim     = h->ne[0];
    const int64_t n_head       = q->ne[1];
    const int64_t n_seq_tokens = q->ne[2];
    const int64_t n_seqs       = q->ne[3];

    // h holds one square {head_dim, head_dim} state matrix per head
    // (the tautological h->ne[0] == head_dim check was removed)
    GGML_ASSERT(h->ne[1] == head_dim);
    GGML_ASSERT(h->ne[2] == n_head);

    GGML_ASSERT(q->ne[0] == head_dim);
    GGML_ASSERT(k->ne[0] == head_dim);
    GGML_ASSERT(v->ne[0] == head_dim);
    GGML_ASSERT(g->ne[0] == head_dim);
    GGML_ASSERT(ggml_are_same_shape(q, k));
    GGML_ASSERT(ggml_are_same_shape(q, v));
    GGML_ASSERT(ggml_are_same_shape(q, g));

    GGML_ASSERT(beta->ne[0] == n_head);
    GGML_ASSERT(beta->ne[1] == n_seq_tokens);
    GGML_ASSERT(beta->ne[2] == n_seqs);

    // one state slot index per sequence in the batch
    GGML_ASSERT(ggml_is_vector(ids));
    GGML_ASSERT(ids->ne[0] == n_seqs);

    // packed output: y (same element count as q) followed by the updated
    // per-sequence hidden states
    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
            ggml_nelements(q) + head_dim * head_dim * n_head * n_seqs);

    result->op     = GGML_OP_KDA_SCAN;
    result->src[0] = h;
    result->src[1] = q;
    result->src[2] = k;
    result->src[3] = v;
    result->src[4] = g;
    result->src[5] = beta;
    result->src[6] = ids;

    return result;
}
// ggml_win_part
struct ggml_tensor * ggml_win_part(