Kimi Linear ggml.c
commit d73d3e51a5
parent bf42bc0606
@@ -999,6 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
+    "KDA_SCAN",
     "WIN_PART",
     "WIN_UNPART",
     "GET_REL_POS",
@@ -1024,7 +1025,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };

-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -5434,6 +5435,70 @@ struct ggml_tensor * ggml_ssm_scan(
     return result;
 }

+// ggml_kda_scan
+
+struct ggml_tensor * ggml_kda_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * h,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * ids) {
+    GGML_ASSERT(ggml_is_contiguous(h));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    {
+        const int64_t head_dim     = h->ne[0];
+        const int64_t n_head       = q->ne[1];
+        const int64_t n_seq_tokens = q->ne[2];
+        const int64_t n_seqs       = q->ne[3];
+
+        GGML_ASSERT(h->ne[0] == head_dim);
+        GGML_ASSERT(h->ne[1] == head_dim);
+        GGML_ASSERT(h->ne[2] == n_head);
+        GGML_ASSERT(q->ne[0] == head_dim);
+        GGML_ASSERT(k->ne[0] == head_dim);
+        GGML_ASSERT(v->ne[0] == head_dim);
+        GGML_ASSERT(g->ne[0] == head_dim);
+        GGML_ASSERT(ggml_are_same_shape(q, k));
+        GGML_ASSERT(ggml_are_same_shape(q, v));
+        GGML_ASSERT(ggml_are_same_shape(q, g));
+        GGML_ASSERT(beta->ne[0] == n_head);
+        GGML_ASSERT(beta->ne[1] == n_seq_tokens);
+        GGML_ASSERT(beta->ne[2] == n_seqs);
+        GGML_ASSERT(ids->ne[0] == n_seqs);
+        GGML_ASSERT(ggml_is_vector(ids));
+    }
+
+    // Output: y (attention output) + updated hidden states
+    //   y:     {head_dim, n_head, n_seq_tokens, n_seqs}
+    //   h_new: {head_dim, head_dim, n_head, n_seqs}
+    const int64_t head_dim     = h->ne[0];
+    const int64_t n_head       = q->ne[1];
+    const int64_t n_seq_tokens = q->ne[2];
+    const int64_t n_seqs       = q->ne[3];
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
+            ggml_nelements(q) + head_dim * head_dim * n_head * n_seqs);
+
+    result->op     = GGML_OP_KDA_SCAN;
+    result->src[0] = h;
+    result->src[1] = q;
+    result->src[2] = k;
+    result->src[3] = v;
+    result->src[4] = g;
+    result->src[5] = beta;
+    result->src[6] = ids;
+
+    return result;
+}
+
 // ggml_win_part

 struct ggml_tensor * ggml_win_part(
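Below is a minimal usage sketch, not part of the commit, showing how the new operator could be called from user code. It assumes the matching ggml_kda_scan declaration is exported from ggml.h by this change, uses illustrative dimensions only, and splits the flat F32 result into y followed by the updated state, an ordering inferred from the layout comment inside the function.

// usage sketch: hypothetical driver, not part of ggml itself
#include "ggml.h"

int main(void) {
    // illustrative sizes only
    const int64_t head_dim = 64, n_head = 8, n_seq_tokens = 16, n_seqs = 1;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // shapes follow the asserts in ggml_kda_scan
    struct ggml_tensor * h    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, head_dim, n_head, n_seqs);
    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * g    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * beta = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * ids  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);

    struct ggml_tensor * out = ggml_kda_scan(ctx, h, q, k, v, g, beta, ids);

    // flat result: ggml_nelements(q) elements of y, then the updated hidden state
    // (ordering assumed from the layout comment in ggml_kda_scan)
    struct ggml_tensor * y     = ggml_view_1d(ctx, out, ggml_nelements(q), 0);
    struct ggml_tensor * h_new = ggml_view_1d(ctx, out, head_dim*head_dim*n_head*n_seqs,
                                              ggml_nelements(q)*ggml_element_size(out));

    // ... feed y/h_new into the rest of the graph and compute with a backend ...
    (void) y; (void) h_new;

    ggml_free(ctx);
    return 0;
}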