diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 17cf4d84bb..8bf562e8b1 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -999,6 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
+    "KDA_SCAN",
     "WIN_PART",
     "WIN_UNPART",
     "GET_REL_POS",
@@ -1024,7 +1025,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -5434,6 +5435,64 @@ struct ggml_tensor * ggml_ssm_scan(
     return result;
 }
 
+// ggml_kda_scan
+
+struct ggml_tensor * ggml_kda_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * h,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * ids) {
+    GGML_ASSERT(ggml_is_contiguous(h));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int64_t head_dim     = h->ne[0];
+    const int64_t n_head       = q->ne[1];
+    const int64_t n_seq_tokens = q->ne[2];
+    const int64_t n_seqs       = q->ne[3];
+
+    // h carries one {head_dim, head_dim} recurrent state per head per sequence
+    GGML_ASSERT(h->ne[1] == head_dim);
+    GGML_ASSERT(h->ne[2] == n_head);
+    GGML_ASSERT(h->ne[3] == n_seqs);
+    GGML_ASSERT(q->ne[0] == head_dim);
+    GGML_ASSERT(k->ne[0] == head_dim);
+    GGML_ASSERT(v->ne[0] == head_dim);
+    GGML_ASSERT(g->ne[0] == head_dim);
+    GGML_ASSERT(ggml_are_same_shape(q, k));
+    GGML_ASSERT(ggml_are_same_shape(q, v));
+    GGML_ASSERT(ggml_are_same_shape(q, g));
+    GGML_ASSERT(beta->ne[0] == n_head);
+    GGML_ASSERT(beta->ne[1] == n_seq_tokens);
+    GGML_ASSERT(beta->ne[2] == n_seqs);
+    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(ids->ne[0] == n_seqs);
+
+    // result packs y {head_dim, n_head, n_seq_tokens, n_seqs} (attention output)
+    // followed by the updated states h_new {head_dim, head_dim, n_head, n_seqs}
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
+            ggml_nelements(q) + head_dim * head_dim * n_head * n_seqs);
+
+    result->op     = GGML_OP_KDA_SCAN;
+    result->src[0] = h;
+    result->src[1] = q;
+    result->src[2] = k;
+    result->src[3] = v;
+    result->src[4] = g;
+    result->src[5] = beta;
+    result->src[6] = ids;
+
+    return result;
+}
+
 // ggml_win_part
 
 struct ggml_tensor * ggml_win_part(