diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index f645e46df9..942844d071 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -2,7 +2,6 @@
 #include "ggml.h"

 #define CHUNK_SIZE 64
-#define BLOCK_SIZE 16

 // Causal Conv1d function for Q,K,V
 // When qkv is 0, it is Q, 1 is K, 2 is V
@@ -88,16 +87,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     ggml_tensor * chunked_causal_mask =
         ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
                  GGML_TRI_TYPE_LOWER);
+
     ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * blocked_causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, BLOCK_SIZE, BLOCK_SIZE), 1.0f),
-                 GGML_TRI_TYPE_LOWER);
-    ggml_tensor * blocked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, BLOCK_SIZE), 1.0f));
-    ggml_tensor * blocked_diag_mask = ggml_add(ctx0, blocked_causal_mask, blocked_identity);
+    ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);

     ggml_build_forward_expand(gf, chunked_causal_mask);
     ggml_build_forward_expand(gf, chunked_identity);
-    ggml_build_forward_expand(gf, blocked_diag_mask);
+    ggml_build_forward_expand(gf, chunked_diag_mask);

     // Kimi dimension constants
     const int64_t n_head = hparams.n_head();
@@ -182,7 +178,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
     std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
         build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
-        build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, blocked_diag_mask, il);
+        build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);

     ggml_tensor * output = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
@@ -396,14 +392,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     ggml_build_forward_expand(gf, cur);
 }

-// Helper to get a slice along dimension 2 (n_chunks dimension)
-static ggml_tensor * get_slice_2d(ggml_context * ctx, ggml_tensor * t, int64_t chunk) {
-    return ggml_view_4d(ctx, t,
-            t->ne[0], t->ne[1], 1, t->ne[3],
-            t->nb[1], t->nb[2], t->nb[3],
-            chunk * t->nb[2]);
-}
-
 /*
 This is a ggml implementation of the naive_chunk_kda function of
 https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
@@ -534,100 +522,46 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
             g_j = g[:, :, i, j:j+1, :]
             A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
 */
-    // compute matrix multiplication block by block and
-    // skip the blocks of zeros in the upper triangle
-    // Initialize Akk and Aqk to zeros
     const int64_t CHB = n_chunks * H_k * n_seqs;

-    ggml_tensor * Akk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
-    Akk = ggml_clamp(ctx0, Akk, 0.0f, 0.0f);
-    ggml_tensor * Aqk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
-    Aqk = ggml_clamp(ctx0, Aqk, 0.0f, 0.0f);
-    const int64_t block_size = BLOCK_SIZE;
-    const int64_t n_blocks = CHUNK_SIZE / BLOCK_SIZE;
-
-    ggml_tensor * k_block[n_blocks];
-    ggml_tensor * q_block[n_blocks];
-    ggml_tensor * gk_block[n_blocks];
-    ggml_tensor * gk_block_bc[n_blocks];
-    for (int64_t j = 0; j < n_blocks; ++j) {
-        int64_t j_start = j * block_size;
-
-        // k_i_block: [S, block_size, C, HB]
-        k_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, k,
-                S_k, block_size, n_chunks, HB,
-                k->nb[1], k->nb[2], k->nb[3],
-                j_start * k->nb[1]));
-
-        k_block[j] = ggml_reshape_4d(ctx0, k_block[j], S_k, block_size, 1, CHB);
-        // q_i_block: [S, block_size, C, HB]
-        q_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, q,
-                S_k, block_size, n_chunks, HB,
-                q->nb[1], q->nb[2], q->nb[3],
-                j_start * q->nb[1]));
+    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
+    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);    // [1, chunk_size, S_k, CHB]

-        q_block[j] = ggml_reshape_4d(ctx0, q_block[j], S_k, block_size, 1, CHB);
-        // gk_j_block: [S, block_size, C, HB]
-        gk_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum,
-                block_size, S_k, n_chunks, HB,
-                gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
-                j_start * gk_cumsum->nb[0]));
-        gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], 1, block_size, S_k, CHB);
+    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
+    // decay_mask [chunk_size,chunk_size,S_k,CHB]
+    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
+    cb(decay_mask, "decay_mask", il);

-        gk_block_bc[j] = ggml_repeat_4d(ctx0, gk_block[j], block_size, block_size, S_k, CHB);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+    cb(decay_mask, "decay_masked", il);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

-        gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], block_size, 1, S_k, CHB);
-    }
+    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);

-    for (int64_t j = 0; j < n_blocks; ++j) {
-        int64_t j_start = j * block_size;
+    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
+    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);

-        ggml_tensor * k_j_block = ggml_reshape_4d(ctx0, k_block[j], S_k, 1, block_size, CHB);
-        for (int64_t i = 0; i <= j; ++i) {
-            int64_t i_start = i * block_size;
+    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
+    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);

-            ggml_tensor * decay_mask = ggml_sub(ctx0, gk_block_bc[j], gk_block[i]);
-            cb(decay_mask, "decay_mask", il);
-
-            // Apply diag_mask only at diagnoal blocks
-            if (i == j) {
-                decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-                decay_mask = ggml_exp(ctx0, decay_mask);
-                decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-            } else {
-                decay_mask = ggml_exp(ctx0, decay_mask);
-            }
-
-            // decay_mask [S_k,BT_j,BT_i,ShHB] *Note* second and third chunk_sizes are switched
-            decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, block_size, block_size, CHB);
-
-            ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_block[i]);
-            ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_block[i]);
-
-            ggml_tensor * Akk_block = ggml_mul_mat(ctx0, decay_k_i, k_j_block);
-            ggml_tensor * Aqk_block = ggml_mul_mat(ctx0, decay_q_i, k_j_block);
-
-            Akk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk_block, block_size, block_size, n_chunks, HB)));
-            Aqk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk_block, block_size, block_size, n_chunks, HB)));
-
-            if (i == j) {
-                Aqk_block = ggml_mul(ctx0, Aqk_block, diag_mask);
-            }
-            Aqk_block = ggml_scale(ctx0, Aqk_block, scale); // scale q
-
-            // Copy to Akk at position [j_start:j_end, i_start:i_end]
-            Akk = ggml_set(ctx0, Akk, Akk_block,
-                    Akk->nb[1], Akk->nb[2], Akk->nb[3],
-                    i_start * Akk->nb[0] + j_start * Akk->nb[1]);
-            Aqk = ggml_set(ctx0, Aqk, Aqk_block,
-                    Aqk->nb[1], Aqk->nb[2], Aqk->nb[3],
-                    i_start * Aqk->nb[0] + j_start * Aqk->nb[1]);
-        }
-    }
-    Akk = ggml_mul(ctx0, Akk, beta);
-    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
+    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
+    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
+    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
+    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
     cb(Akk, "Akk", il);
     cb(Aqk, "Aqk", il);
+    Akk = ggml_mul(ctx0, Akk, beta);
+    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
+    cb(Akk, "attn_pre_solve", il);
+
+    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
+    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
+    cb(Aqk, "Aqk_masked", il);
+
     // for i in range(1, chunk_size):
     //     row = attn[..., i, :i].clone()
     //     sub = attn[..., :i, :i].clone()
@@ -664,14 +598,27 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
     cb(new_state, "new_state", il);

     for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        ggml_tensor * k_chunk = get_slice_2d(ctx0, k, chunk);
-        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk);
-        ggml_tensor * vb_chunk = get_slice_2d(ctx0, vb, chunk);
+// extract one chunk worth of data
+        auto chunkify = [=](ggml_tensor * t) {
+            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
+                    t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+        auto chunkify_A = [=](ggml_tensor * t) {
+            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
+                    t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };

-        ggml_tensor * gk_cs_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, gk_cumsum, chunk));
-        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);
+
+// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
+        ggml_tensor * k_chunk = chunkify(k);
+        ggml_tensor * q_chunk = chunkify(q);
+        ggml_tensor * vb_chunk = chunkify(vb);
+
+// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
+        ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
+        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
         ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);

-        ggml_tensor * Aqk_chunk = get_slice_2d(ctx0, Aqk, chunk);
+        ggml_tensor * Aqk_chunk = chunkify_A(Aqk);

         ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
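For reference, the intra-chunk matrices that the new decay-mask path builds above correspond to the following standalone scalar sketch. It is illustrative only: one head, one chunk, toy sizes; the names q, k, g, beta and the masking follow the naive_chunk_kda pseudocode quoted in the comment block rather than the actual ggml tensor layout, and the 1/sqrt(D) query scale and per-row beta broadcast are assumptions standing in for the patch's scale and beta tensors.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for the intra-chunk KDA matrices (Aqk, Akk) built by the
// vectorized decay-mask path: one head, one chunk of T tokens, head size D.
// g holds the per-dimension cumulative log-decay inside the chunk (gk_cumsum).
// Mirrors the quoted naive_chunk_kda line:
//   A[..., j] = einsum('c d, d -> c', q * (g_c - g_j).exp(), k_j)
int main() {
    const int T = 4, D = 3;                           // toy sizes; the patch uses CHUNK_SIZE x head size
    const double scale = 1.0 / std::sqrt((double) D); // assumed query scale ("scale" in the patch)

    std::vector<std::vector<double>> q(T, std::vector<double>(D, 0.10));
    std::vector<std::vector<double>> k(T, std::vector<double>(D, 0.20));
    std::vector<std::vector<double>> g(T, std::vector<double>(D));
    std::vector<double> beta(T, 0.50);                // per-token update gate

    for (int t = 0; t < T; ++t) {
        for (int d = 0; d < D; ++d) {
            g[t][d] = -0.05 * (t + 1);                // cumulative log-decay, more negative as t grows
        }
    }

    std::vector<std::vector<double>> Aqk(T, std::vector<double>(T, 0.0));
    std::vector<std::vector<double>> Akk(T, std::vector<double>(T, 0.0));

    for (int r = 0; r < T; ++r) {                     // row = current token inside the chunk
        for (int c = 0; c <= r; ++c) {                // col = earlier token, causal within the chunk
            double aq = 0.0, ak = 0.0;
            for (int d = 0; d < D; ++d) {
                const double decay = std::exp(g[r][d] - g[c][d]); // one decay_mask entry
                aq += q[r][d] * decay * k[c][d];
                ak += k[r][d] * decay * k[c][d];
            }
            Aqk[r][c] = aq * scale;                   // kept on and below the diagonal (diag_mask)
            if (c < r) {
                Akk[r][c] = -(ak * beta[r]);          // strictly lower triangle (causal_mask), scaled by beta, negated
            }
        }
    }

    printf("Aqk[3][1] = %.6f  Akk[3][1] = %.6f\n", Aqk[3][1], Akk[3][1]);
    return 0;
}

In the patch this double loop is expressed once per graph build as a single decay_mask tensor followed by two ggml_mul_mat calls over all chunks and heads, replacing the previous BLOCK_SIZE-by-BLOCK_SIZE tiling that assembled Akk and Aqk piecewise with ggml_set.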