Compute Akk and Aqk as a 4x4 grid of 16x16 blocks

Yee Man Chan 2026-02-06 19:03:09 +08:00
parent 6456393bbd
commit 17cd6e8514
1 changed file with 100 additions and 33 deletions
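For orientation, a minimal standalone sketch of the tiling pattern this commit introduces (plain C++, no ggml; the function name and the head dimension are illustrative only, not part of the patch): each 64x64 chunk-level score matrix is assembled from a 4x4 grid of 16x16 tiles, and tiles strictly above the block diagonal are skipped because the causal mask zeroes them.

#include <cstdio>

constexpr int CHUNK = 64;            // CHUNK_SIZE in the patch
constexpr int BLOCK = 16;            // BLOCK_SIZE in the patch
constexpr int NB    = CHUNK / BLOCK; // 4 -> the "4x4" grid of 16x16 blocks
constexpr int DIM   = 8;             // head dimension, arbitrary for this sketch

// A[r][c] = sum_d q[r][d] * k[c][d], assembled one BLOCK x BLOCK tile at a time.
// Tiles with block-column i > block-row j are never touched: they stay zero.
static void causal_tiled_scores(const float q[CHUNK][DIM],
                                const float k[CHUNK][DIM],
                                float A[CHUNK][CHUNK]) {
    for (int r = 0; r < CHUNK; ++r)
        for (int c = 0; c < CHUNK; ++c)
            A[r][c] = 0.0f;                        // zero-init, as the patch does

    for (int j = 0; j < NB; ++j) {                 // block-row (query side)
        for (int i = 0; i <= j; ++i) {             // block-col (key side), skip i > j
            for (int r = j * BLOCK; r < (j + 1) * BLOCK; ++r) {
                for (int c = i * BLOCK; c < (i + 1) * BLOCK; ++c) {
                    if (i == j && c > r) continue; // causal mask inside the diagonal tile
                    float acc = 0.0f;
                    for (int d = 0; d < DIM; ++d) acc += q[r][d] * k[c][d];
                    A[r][c] = acc;
                }
            }
        }
    }
}

int main() {
    static float q[CHUNK][DIM], k[CHUNK][DIM], A[CHUNK][CHUNK];
    q[1][0] = k[0][0] = 1.0f;                      // token 1 attends to token 0
    causal_tiled_scores(q, k, A);
    std::printf("A[1][0] = %.1f  A[0][1] = %.1f\n", A[1][0], A[0][1]); // 1.0  0.0
    return 0;
}

Of the 4x4 = 16 tiles per chunk, only the 10 on or below the block diagonal are computed. The diff below follows the same pattern with ggml: per-block views of q, k and the gating cumsum, one ggml_mul_mat per tile, and ggml_acc to place each tile into the zero-initialized Akk and Aqk.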


@@ -2,6 +2,7 @@
#include "ggml.h"
#define CHUNK_SIZE 64
#define BLOCK_SIZE 16
// Causal Conv1d function for Q,K,V
// When qkv is 0, it is Q, 1 is K, 2 is V
@@ -84,13 +85,16 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
ggml_tensor * chunked_causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
ggml_tensor * blocked_causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, BLOCK_SIZE, BLOCK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * blocked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, BLOCK_SIZE), 1.0f));
ggml_tensor * blocked_diag_mask = ggml_add(ctx0, blocked_causal_mask, blocked_identity);
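// blocked_diag_mask is the BLOCK_SIZE x BLOCK_SIZE counterpart of chunked_diag_mask;
// it is applied to the 16x16 diagonal blocks inside build_kda_chunking below.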
ggml_build_forward_expand(gf, chunked_causal_mask);
ggml_build_forward_expand(gf, chunked_identity);
ggml_build_forward_expand(gf, chunked_diag_mask);
ggml_build_forward_expand(gf, blocked_diag_mask);
// Kimi dimension constants
const int64_t n_head = hparams.n_head();
@@ -175,7 +179,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Choose between build_kda_chunking and build_kda_autoregressive based on n_seq_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, blocked_diag_mask, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@@ -404,6 +408,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
ggml_tensor * identity,
ggml_tensor * diag_mask,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(gk));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S_k = q->ne[0];
@@ -519,45 +528,99 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
g_j = g[:, :, i, j:j+1, :]
A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
*/
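// In scalar form, per chunk and head: A[r][c] = sum_d q[r][d] * exp(g[r][d] - g[c][d]) * k[c][d],
// evaluated only for c <= r (causal); Akk is the analogous product with k in place of q.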
// compute matrix multiplication block by block and
// skip the blocks of zeros in the upper triangle
// Initialize Akk and Aqk to zeros
const int64_t CHB = n_chunks * H_k * n_seqs;
ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB]
ggml_tensor * Akk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
Akk = ggml_clamp(ctx0, Akk, 0.0f, 0.0f);
ggml_tensor * Aqk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
Aqk = ggml_clamp(ctx0, Aqk, 0.0f, 0.0f);
const int64_t block_size = BLOCK_SIZE;
const int64_t n_blocks = CHUNK_SIZE / BLOCK_SIZE;
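// With CHUNK_SIZE = 64 and BLOCK_SIZE = 16, n_blocks = 4: each chunk-level
// attention matrix is built from a 4x4 grid of 16x16 blocks (see commit title).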
ggml_tensor * k_block[n_blocks];
ggml_tensor * q_block[n_blocks];
ggml_tensor * gk_block[n_blocks];
ggml_tensor * gk_block_bc[n_blocks];
for (int64_t j = 0; j < n_blocks; ++j) {
int64_t j_start = j * block_size;
// k_i_block: [S, block_size, C, HB]
k_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, k,
S_k, block_size, n_chunks, HB,
k->nb[1], k->nb[2], k->nb[3],
j_start * k->nb[1]));
k_block[j] = ggml_reshape_4d(ctx0, k_block[j], S_k, block_size, 1, CHB);
// q_i_block: [S, block_size, C, HB]
q_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, q,
S_k, block_size, n_chunks, HB,
q->nb[1], q->nb[2], q->nb[3],
j_start * q->nb[1]));
ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
// decay_mask [chunk_size,chunk_size,S_k,CHB]
ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
cb(decay_mask, "decay_mask", il);
q_block[j] = ggml_reshape_4d(ctx0, q_block[j], S_k, block_size, 1, CHB);
// gk_j_block: [block_size, S, C, HB]
gk_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum,
block_size, S_k, n_chunks, HB,
gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
j_start * gk_cumsum->nb[0]));
gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], 1, block_size, S_k, CHB);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
cb(decay_mask, "decay_masked", il);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
gk_block_bc[j] = ggml_repeat_4d(ctx0, gk_block[j], block_size, block_size, S_k, CHB);
// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], block_size, 1, S_k, CHB);
}
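// The per-block views of q, k and gk_cumsum are prepared above; the pairwise
// block products are formed in the nested loop below.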
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
for (int64_t j = 0; j < n_blocks; ++j) {
int64_t j_start = j * block_size;
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
ggml_tensor * k_j_block = ggml_reshape_4d(ctx0, k_block[j], S_k, 1, block_size, CHB);
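// Only blocks on or below the diagonal (i <= j) are computed; blocks strictly
// above the diagonal are zero under the causal mask and are skipped entirely.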
for (int64_t i = 0; i <= j; ++i) {
int64_t i_start = i * block_size;
// decay_k_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
cb(Akk, "Akk", il);
cb(Aqk, "Aqk", il);
ggml_tensor * decay_mask = ggml_sub(ctx0, gk_block_bc[j], gk_block[i]);
cb(decay_mask, "decay_mask", il);
// Apply diag_mask only at diagonal blocks
if (i == j) {
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
} else {
decay_mask = ggml_exp(ctx0, decay_mask);
}
// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third block_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, block_size, block_size, CHB);
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_block[i]);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_block[i]);
ggml_tensor * Akk_block = ggml_mul_mat(ctx0, decay_k_i, k_j_block);
ggml_tensor * Aqk_block = ggml_mul_mat(ctx0, decay_q_i, k_j_block);
Akk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk_block, block_size, block_size, n_chunks, HB)));
Aqk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk_block, block_size, block_size, n_chunks, HB)));
if (i == j) {
Aqk_block = ggml_mul(ctx0, Aqk_block, diag_mask);
}
Aqk_block = ggml_scale(ctx0, Aqk_block, scale); // scale q
// Accumulate into Akk at position [j_start:j_end, i_start:i_end]
Akk = ggml_acc(ctx0, Akk, Akk_block,
Akk->nb[1], Akk->nb[2], Akk->nb[3],
i_start * Akk->nb[0] + j_start * Akk->nb[1]);
Aqk = ggml_acc(ctx0, Aqk, Aqk_block,
Aqk->nb[1], Aqk->nb[2], Aqk->nb[3],
i_start * Aqk->nb[0] + j_start * Aqk->nb[1]);
}
}
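// Scale Akk by beta and negate its strictly-lower part (causal_mask) before the
// in-place solve outlined in the python-style comment below.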
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
cb(Akk, "attn_pre_solve", il);
Aqk = ggml_mul(ctx0, Aqk, diag_mask);
Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
cb(Aqk, "Aqk_masked", il);
cb(Akk, "Akk", il);
cb(Aqk, "Aqk", il);
// for i in range(1, chunk_size):
// row = attn[..., i, :i].clone()
@@ -690,8 +753,12 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoreg
ggml_tensor * beta,
ggml_tensor * state,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(gk));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];