revert back to normal implementation

parent b2d02ad6b9
commit 6286253e39
@@ -2,7 +2,6 @@
#include "ggml.h"

#define CHUNK_SIZE 64
#define BLOCK_SIZE 16

// Causal Conv1d function for Q,K,V
// When qkv is 0, it is Q, 1 is K, 2 is V
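For context, "causal" means the convolution at timestep t only sees inputs at t and earlier (left padding, no lookahead); the ggml-based function itself is not part of this hunk, and the qkv index is just a selector for which of Q/K/V is being convolved. A minimal standalone sketch of the idea, with made-up toy data:

#include <cstdio>
#include <vector>

// Toy causal conv1d: y[t] = sum_k w[k] * x[t-k], treating x[t-k] as 0 when t-k < 0.
static std::vector<float> causal_conv1d(const std::vector<float> & x, const std::vector<float> & w) {
    std::vector<float> y(x.size(), 0.0f);
    for (size_t t = 0; t < x.size(); ++t) {
        for (size_t k = 0; k < w.size(); ++k) {
            if (t >= k) {
                y[t] += w[k] * x[t - k];   // only current and past inputs contribute
            }
        }
    }
    return y;
}

int main() {
    std::vector<float> x = { 1, 2, 3, 4 };
    std::vector<float> w = { 0.5f, 0.25f };  // short causal kernel
    std::vector<float> y = causal_conv1d(x, w);
    // In the model, a qkv value of 0/1/2 would select the Q/K/V projection's own
    // kernel; the convolution itself is the same for all three.
    printf("%.2f %.2f %.2f %.2f\n", y[0], y[1], y[2], y[3]);
    return 0;
}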
@@ -88,16 +87,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
ggml_tensor * chunked_causal_mask =
    ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
    GGML_TRI_TYPE_LOWER);

ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
ggml_tensor * blocked_causal_mask =
    ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, BLOCK_SIZE, BLOCK_SIZE), 1.0f),
    GGML_TRI_TYPE_LOWER);
ggml_tensor * blocked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, BLOCK_SIZE), 1.0f));
ggml_tensor * blocked_diag_mask = ggml_add(ctx0, blocked_causal_mask, blocked_identity);
ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);

ggml_build_forward_expand(gf, chunked_causal_mask);
ggml_build_forward_expand(gf, chunked_identity);
ggml_build_forward_expand(gf, blocked_diag_mask);
ggml_build_forward_expand(gf, chunked_diag_mask);

// Kimi dimension constants
const int64_t n_head = hparams.n_head();
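For reference, the masks built above are simple triangular patterns: assuming GGML_TRI_TYPE_LOWER produces the strictly lower triangle (which would be why the identity is added separately to form the diag masks), the causal mask has ones below the diagonal and the diag mask has ones on and below it. A small standalone sketch of the 4x4 case:

#include <cstdio>

int main() {
    const int n = 4;
    float causal[n][n];  // strictly lower triangle of ones (assumed GGML_TRI_TYPE_LOWER behavior)
    float diag[n][n];    // causal + identity: lower triangle including the diagonal
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            causal[i][j] = (j < i) ? 1.0f : 0.0f;
            diag[i][j]   = causal[i][j] + ((i == j) ? 1.0f : 0.0f);
        }
    }
    printf("causal mask        diag mask\n");
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) printf("%.0f ", causal[i][j]);
        printf("     ");
        for (int j = 0; j < n; ++j) printf("%.0f ", diag[i][j]);
        printf("\n");
    }
    return 0;
}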
@@ -182,7 +178,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Choose between build_kda_autoregressive and build_kda_chunking based on n_seq_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
    build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
    build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, blocked_diag_mask, il);
    build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);

ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@@ -396,14 +392,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
    ggml_build_forward_expand(gf, cur);
}

// Helper to get a slice along dimension 2 (n_chunks dimension)
static ggml_tensor * get_slice_2d(ggml_context * ctx, ggml_tensor * t, int64_t chunk) {
    return ggml_view_4d(ctx, t,
        t->ne[0], t->ne[1], 1, t->ne[3],
        t->nb[1], t->nb[2], t->nb[3],
        chunk * t->nb[2]);
}

/*
This is a ggml implementation of the naive_chunk_kda function of
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
@@ -534,100 +522,46 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
g_j = g[:, :, i, j:j+1, :]
A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
*/
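Written out per element, the einsum above says that inside one chunk, A[c][j] is the decay-weighted dot product sum over d of q[c][d] * exp(g[c][d] - g[j][d]) * k[j][d], with g the cumulative log-decay within the chunk. A minimal sketch of just that arithmetic on toy data (not the ggml graph code below):

#include <cmath>
#include <cstdio>

int main() {
    const int T = 4;  // chunk length
    const int D = 2;  // head dimension
    float q[T][D], k[T][D], g[T][D];
    // deterministic toy values so the example is reproducible
    for (int t = 0; t < T; ++t) {
        for (int d = 0; d < D; ++d) {
            q[t][d] = 0.1f * (t + 1);
            k[t][d] = 0.2f * (d + 1);
            g[t][d] = -0.1f * (t + 1);      // cumulative log-decay, more negative later in the chunk
        }
    }
    float A[T][T] = {};
    for (int c = 0; c < T; ++c) {
        for (int j = 0; j < T; ++j) {
            if (j > c) continue;            // causal: only current and earlier positions contribute
            for (int d = 0; d < D; ++d) {
                A[c][j] += q[c][d] * std::exp(g[c][d] - g[j][d]) * k[j][d];
            }
        }
    }
    for (int c = 0; c < T; ++c) {
        for (int j = 0; j < T; ++j) printf("%6.3f ", A[c][j]);
        printf("\n");
    }
    return 0;
}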
// compute matrix multiplication block by block and
// skip the blocks of zeros in the upper triangle
// Initialize Akk and Aqk to zeros
const int64_t CHB = n_chunks * H_k * n_seqs;
ggml_tensor * Akk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
Akk = ggml_clamp(ctx0, Akk, 0.0f, 0.0f);
ggml_tensor * Aqk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
Aqk = ggml_clamp(ctx0, Aqk, 0.0f, 0.0f);
const int64_t block_size = BLOCK_SIZE;
const int64_t n_blocks = CHUNK_SIZE / BLOCK_SIZE;

ggml_tensor * k_block[n_blocks];
ggml_tensor * q_block[n_blocks];
ggml_tensor * gk_block[n_blocks];
ggml_tensor * gk_block_bc[n_blocks];
for (int64_t j = 0; j < n_blocks; ++j) {
int64_t j_start = j * block_size;

// k_i_block: [S, block_size, C, HB]
k_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, k,
    S_k, block_size, n_chunks, HB,
    k->nb[1], k->nb[2], k->nb[3],
    j_start * k->nb[1]));

k_block[j] = ggml_reshape_4d(ctx0, k_block[j], S_k, block_size, 1, CHB);
// q_i_block: [S, block_size, C, HB]
q_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, q,
    S_k, block_size, n_chunks, HB,
    q->nb[1], q->nb[2], q->nb[3],
    j_start * q->nb[1]));
ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB]

q_block[j] = ggml_reshape_4d(ctx0, q_block[j], S_k, block_size, 1, CHB);
// gk_j_block: [S, block_size, C, HB]
gk_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum,
    block_size, S_k, n_chunks, HB,
    gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
    j_start * gk_cumsum->nb[0]));
gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], 1, block_size, S_k, CHB);
ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
// decay_mask [chunk_size,chunk_size,S_k,CHB]
ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
cb(decay_mask, "decay_mask", il);

gk_block_bc[j] = ggml_repeat_4d(ctx0, gk_block[j], block_size, block_size, S_k, CHB);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
cb(decay_mask, "decay_masked", il);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], block_size, 1, S_k, CHB);
}
// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
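As the comment above says, the block-wise path skips the all-zero upper-triangle blocks: the block loop that follows iterates j over blocks and i only from 0 to j, so products are formed only for blocks on or below the diagonal, and the diagonal blocks additionally get diag_mask. A tiny sketch of that visiting order, assuming n_blocks = CHUNK_SIZE / BLOCK_SIZE = 4:

#include <cstdio>

int main() {
    const int n_blocks = 4;  // CHUNK_SIZE 64 / BLOCK_SIZE 16
    for (int j = 0; j < n_blocks; ++j) {
        for (int i = 0; i <= j; ++i) {   // upper-triangle blocks (i > j) are never visited
            printf("compute block (i=%d, j=%d)%s\n", i, j,
                   i == j ? "  <- diagonal block, diag_mask applied" : "");
        }
    }
    return 0;
}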
for (int64_t j = 0; j < n_blocks; ++j) {
int64_t j_start = j * block_size;
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);

ggml_tensor * k_j_block = ggml_reshape_4d(ctx0, k_block[j], S_k, 1, block_size, CHB);
for (int64_t i = 0; i <= j; ++i) {
int64_t i_start = i * block_size;
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);

ggml_tensor * decay_mask = ggml_sub(ctx0, gk_block_bc[j], gk_block[i]);
cb(decay_mask, "decay_mask", il);

// Apply diag_mask only at diagonal blocks
if (i == j) {
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
} else {
decay_mask = ggml_exp(ctx0, decay_mask);
}

// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, block_size, block_size, CHB);
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_block[i]);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_block[i]);

ggml_tensor * Akk_block = ggml_mul_mat(ctx0, decay_k_i, k_j_block);
ggml_tensor * Aqk_block = ggml_mul_mat(ctx0, decay_q_i, k_j_block);

Akk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk_block, block_size, block_size, n_chunks, HB)));
Aqk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk_block, block_size, block_size, n_chunks, HB)));

if (i == j) {
Aqk_block = ggml_mul(ctx0, Aqk_block, diag_mask);
}
Aqk_block = ggml_scale(ctx0, Aqk_block, scale); // scale q

// Copy to Akk at position [j_start:j_end, i_start:i_end]
Akk = ggml_set(ctx0, Akk, Akk_block,
    Akk->nb[1], Akk->nb[2], Akk->nb[3],
    i_start * Akk->nb[0] + j_start * Akk->nb[1]);
Aqk = ggml_set(ctx0, Aqk, Aqk_block,
    Aqk->nb[1], Aqk->nb[2], Aqk->nb[3],
    i_start * Aqk->nb[0] + j_start * Aqk->nb[1]);
}
}
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
// decay_k_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
cb(Akk, "Akk", il);
cb(Aqk, "Aqk", il);

Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
cb(Akk, "attn_pre_solve", il);

Aqk = ggml_mul(ctx0, Aqk, diag_mask);
Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
cb(Aqk, "Aqk_masked", il);

// for i in range(1, chunk_size):
//     row = attn[..., i, :i].clone()
//     sub = attn[..., :i, :i].clone()
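In the referenced naive.py this loop continues with attn[..., i, :i] = row + row @ sub (that continuation is not shown in this hunk, so it is an assumption taken from the referenced file). Iterating it for increasing i is a forward substitution: for a strictly lower-triangular input, I plus the result equals the inverse of (I - input). A small standalone sketch of that pattern:

#include <cstdio>

int main() {
    const int T = 4;
    // a strictly lower-triangular "attn", like Akk after the masking and negation above
    float attn[T][T] = {
        { 0.0f, 0.0f, 0.0f, 0.0f },
        { 0.5f, 0.0f, 0.0f, 0.0f },
        { 0.2f, 0.3f, 0.0f, 0.0f },
        { 0.1f, 0.4f, 0.6f, 0.0f },
    };
    // forward substitution: row i becomes row + row * sub, using only already-updated rows < i
    for (int i = 1; i < T; ++i) {
        float row[T];
        for (int j = 0; j < i; ++j) row[j] = attn[i][j];
        for (int j = 0; j < i; ++j) {
            float acc = row[j];
            for (int m = 0; m < i; ++m) {
                acc += row[m] * attn[m][j];   // (row @ sub)[j]
            }
            attn[i][j] = acc;
        }
    }
    // afterwards, I + attn equals the inverse of (I - original attn)
    for (int i = 0; i < T; ++i) {
        for (int j = 0; j < T; ++j) printf("%6.3f ", attn[i][j] + (i == j ? 1.0f : 0.0f));
        printf("\n");
    }
    return 0;
}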
@@ -664,14 +598,27 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
cb(new_state, "new_state", il);

for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
ggml_tensor * k_chunk = get_slice_2d(ctx0, k, chunk);
ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk);
ggml_tensor * vb_chunk = get_slice_2d(ctx0, vb, chunk);
// extract one chunk worth of data
auto chunkify = [=](ggml_tensor * t) {
    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
};
auto chunkify_A = [=](ggml_tensor * t) {
    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
};
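The chunkify lambdas above pick chunk number `chunk` along dimension 2 purely by offsetting the view by chunk * t->nb[2] bytes (ggml's nb[] are strides in bytes). The same arithmetic on a flat array, using element strides instead of bytes, as a standalone illustration:

#include <cstdio>
#include <vector>

int main() {
    // toy tensor [ne0=3, ne1=4 (chunk_size), ne2=2 (n_chunks), ne3=1] stored contiguously
    const int ne0 = 3, ne1 = 4, ne2 = 2;
    std::vector<float> data(ne0 * ne1 * ne2);
    for (size_t i = 0; i < data.size(); ++i) data[i] = (float) i;

    // strides in elements; ggml's nb[] are the same idea, but measured in bytes
    const int nb1 = ne0;
    const int nb2 = ne0 * ne1;

    const int chunk = 1;
    const float * chunk_view = data.data() + chunk * nb2;  // analogous to offset t->nb[2] * chunk
    for (int i1 = 0; i1 < ne1; ++i1) {
        for (int i0 = 0; i0 < ne0; ++i0) printf("%4.0f ", chunk_view[i1 * nb1 + i0]);
        printf("\n");
    }
    return 0;
}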
ggml_tensor * gk_cs_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, gk_cumsum, chunk));
ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);

// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
ggml_tensor * k_chunk = chunkify(k);
ggml_tensor * q_chunk = chunkify(q);
ggml_tensor * vb_chunk = chunkify(vb);

// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
ggml_tensor * Aqk_chunk = get_slice_2d(ctx0, Aqk, chunk);
ggml_tensor * Aqk_chunk = chunkify_A(Aqk);

ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);