revert to normal implementation

Yee Man Chan 2026-02-13 08:06:29 +08:00
parent b2d02ad6b9
commit 6286253e39
1 changed file with 53 additions and 106 deletions

@@ -2,7 +2,6 @@
#include "ggml.h"
#define CHUNK_SIZE 64
#define BLOCK_SIZE 16
// Causal Conv1d function for Q,K,V
// When qkv is 0, it is Q, 1 is K, 2 is V
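// A minimal usage sketch of that convention (the helper name and signature are
// illustrative assumptions, not taken from this file):
//   ggml_tensor * q_conv = build_causal_conv1d(ctx0, Qcur, /*qkv=*/0);
//   ggml_tensor * k_conv = build_causal_conv1d(ctx0, Kcur, /*qkv=*/1);
//   ggml_tensor * v_conv = build_causal_conv1d(ctx0, Vcur, /*qkv=*/2);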
@@ -88,16 +87,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
ggml_tensor * chunked_causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
ggml_tensor * blocked_causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, BLOCK_SIZE, BLOCK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * blocked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, BLOCK_SIZE), 1.0f));
ggml_tensor * blocked_diag_mask = ggml_add(ctx0, blocked_causal_mask, blocked_identity);
ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
ggml_build_forward_expand(gf, chunked_causal_mask);
ggml_build_forward_expand(gf, chunked_identity);
ggml_build_forward_expand(gf, blocked_diag_mask);
ggml_build_forward_expand(gf, chunked_diag_mask);
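// Toy illustration of the chunk-level masks above for a hypothetical CHUNK_SIZE
// of 4, assuming GGML_TRI_TYPE_LOWER keeps only the strictly lower triangle
// (the diagonal then comes from chunked_identity):
//   chunked_causal_mask    chunked_identity    chunked_diag_mask
//   0 0 0 0                1 0 0 0             1 0 0 0
//   1 0 0 0                0 1 0 0             1 1 0 0
//   1 1 0 0                0 0 1 0             1 1 1 0
//   1 1 1 0                0 0 0 1             1 1 1 1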
// Kimi dimension constants
const int64_t n_head = hparams.n_head();
@@ -182,7 +178,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Choose between build_kda_chunking and build_kda_autoregressive based on n_seq_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, blocked_diag_mask, il);
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
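// n_seq_tokens == 1 is the single-token decode path, handled by the
// autoregressive (recurrent) update; longer sequences such as prompt
// processing go through the chunked formulation.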
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@@ -396,14 +392,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
ggml_build_forward_expand(gf, cur);
}
// Helper to get a slice along dimension 2 (n_chunks dimension)
static ggml_tensor * get_slice_2d(ggml_context * ctx, ggml_tensor * t, int64_t chunk) {
return ggml_view_4d(ctx, t,
t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3],
chunk * t->nb[2]);
}
/*
This is a ggml implementation of the naive_chunk_kda function of
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
@@ -534,100 +522,46 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
g_j = g[:, :, i, j:j+1, :]
A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
*/
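// Restated as an equation, the referenced pseudocode builds, inside each chunk,
// the lower-triangular scores
//   A[i][j] = sum_d q[i][d] * exp(gcum[i][d] - gcum[j][d]) * k[j][d]   for j <= i,
// where gcum is the per-chunk cumulative sum of the gate g; the decay_mask and
// the Akk/Aqk mat-muls below evaluate this with k (for Akk) or q (for Aqk) as
// the left operand.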
// compute matrix multiplication block by block and
// skip the blocks of zeros in the upper triangle
// Initialize Akk and Aqk to zeros
const int64_t CHB = n_chunks * H_k * n_seqs;
ggml_tensor * Akk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
Akk = ggml_clamp(ctx0, Akk, 0.0f, 0.0f);
ggml_tensor * Aqk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
Aqk = ggml_clamp(ctx0, Aqk, 0.0f, 0.0f);
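// clamping to the range [0, 0] forces every element of the freshly allocated
// Akk/Aqk buffers to zero before the per-block results are written into them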
const int64_t block_size = BLOCK_SIZE;
const int64_t n_blocks = CHUNK_SIZE / BLOCK_SIZE;
ggml_tensor * k_block[n_blocks];
ggml_tensor * q_block[n_blocks];
ggml_tensor * gk_block[n_blocks];
ggml_tensor * gk_block_bc[n_blocks];
for (int64_t j = 0; j < n_blocks; ++j) {
int64_t j_start = j * block_size;
// k_i_block: [S, block_size, C, HB]
k_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, k,
S_k, block_size, n_chunks, HB,
k->nb[1], k->nb[2], k->nb[3],
j_start * k->nb[1]));
k_block[j] = ggml_reshape_4d(ctx0, k_block[j], S_k, block_size, 1, CHB);
// q_i_block: [S, block_size, C, HB]
q_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, q,
S_k, block_size, n_chunks, HB,
q->nb[1], q->nb[2], q->nb[3],
j_start * q->nb[1]));
ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB]
q_block[j] = ggml_reshape_4d(ctx0, q_block[j], S_k, block_size, 1, CHB);
// gk_j_block: [S, block_size, C, HB]
gk_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum,
block_size, S_k, n_chunks, HB,
gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
j_start * gk_cumsum->nb[0]));
gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], 1, block_size, S_k, CHB);
ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
// decay_mask [chunk_size,chunk_size,S_k,CHB]
ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
cb(decay_mask, "decay_mask", il);
gk_block_bc[j] = ggml_repeat_4d(ctx0, gk_block[j], block_size, block_size, S_k, CHB);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
cb(decay_mask, "decay_masked", il);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], block_size, 1, S_k, CHB);
}
// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
for (int64_t j = 0; j < n_blocks; ++j) {
int64_t j_start = j * block_size;
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
ggml_tensor * k_j_block = ggml_reshape_4d(ctx0, k_block[j], S_k, 1, block_size, CHB);
for (int64_t i = 0; i <= j; ++i) {
int64_t i_start = i * block_size;
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
ggml_tensor * decay_mask = ggml_sub(ctx0, gk_block_bc[j], gk_block[i]);
cb(decay_mask, "decay_mask", il);
// Apply diag_mask only at diagonal blocks
if (i == j) {
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
} else {
decay_mask = ggml_exp(ctx0, decay_mask);
}
// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third block_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, block_size, block_size, CHB);
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_block[i]);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_block[i]);
ggml_tensor * Akk_block = ggml_mul_mat(ctx0, decay_k_i, k_j_block);
ggml_tensor * Aqk_block = ggml_mul_mat(ctx0, decay_q_i, k_j_block);
Akk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk_block, block_size, block_size, n_chunks, HB)));
Aqk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk_block, block_size, block_size, n_chunks, HB)));
if (i == j) {
Aqk_block = ggml_mul(ctx0, Aqk_block, diag_mask);
}
Aqk_block = ggml_scale(ctx0, Aqk_block, scale); // scale q
// Copy to Akk at position [j_start:j_end, i_start:i_end]
Akk = ggml_set(ctx0, Akk, Akk_block,
Akk->nb[1], Akk->nb[2], Akk->nb[3],
i_start * Akk->nb[0] + j_start * Akk->nb[1]);
Aqk = ggml_set(ctx0, Aqk, Aqk_block,
Aqk->nb[1], Aqk->nb[2], Aqk->nb[3],
i_start * Aqk->nb[0] + j_start * Aqk->nb[1]);
}
}
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
// decay_k_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
cb(Akk, "Akk", il);
cb(Aqk, "Aqk", il);
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
cb(Akk, "attn_pre_solve", il);
Aqk = ggml_mul(ctx0, Aqk, diag_mask);
Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
cb(Aqk, "Aqk_masked", il);
// for i in range(1, chunk_size):
// row = attn[..., i, :i].clone()
// sub = attn[..., :i, :i].clone()
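// The loop sketched above is, in effect, a forward substitution: each row i adds
// to itself its product with the previously processed sub-block attn[:i, :i],
// which amounts to inverting the unit lower-triangular system formed from the
// masked, beta-scaled Akk.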
@@ -664,14 +598,27 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
cb(new_state, "new_state", il);
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
ggml_tensor * k_chunk = get_slice_2d(ctx0, k, chunk);
ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk);
ggml_tensor * vb_chunk = get_slice_2d(ctx0, vb, chunk);
// extract one chunk worth of data
auto chunkify = [=](ggml_tensor * t) {
return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
};
auto chunkify_A = [=](ggml_tensor * t) {
return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
};
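// these lambdas do the same dimension-2 slicing as the get_slice_2d helper
// shown earlier in this diff: the byte offset t->nb[2] * chunk selects the
// chunk-th slice along the chunk dimension, so each call yields a
// [ne0, chunk_size, 1, ne3] view of t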
ggml_tensor * gk_cs_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, gk_cumsum, chunk));
ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);
// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
ggml_tensor * k_chunk = chunkify(k);
ggml_tensor * q_chunk = chunkify(q);
ggml_tensor * vb_chunk = chunkify(vb);
// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
ggml_tensor * Aqk_chunk = get_slice_2d(ctx0, Aqk, chunk);
ggml_tensor * Aqk_chunk = chunkify_A(Aqk);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
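// ggml_permute(.., 1, 0, 2, 3) swaps the first two axes, i.e. it transposes each
// S_v x S_v per-head state matrix, and ggml_cont_4d materialises the transposed
// view contiguously so downstream ops (e.g. ggml_mul_mat) get a plain operand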