diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index 0f037d1a39..ebaabf7ee6 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -2,6 +2,7 @@
 #include "ggml.h"
 
 #define CHUNK_SIZE 64
+#define BLOCK_SIZE 16
 
 // Causal Conv1d function for Q,K,V
 // When qkv is 0, it is Q, 1 is K, 2 is V
@@ -84,13 +85,16 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     ggml_tensor * chunked_causal_mask =
         ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), GGML_TRI_TYPE_LOWER);
-    ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
+    ggml_tensor * blocked_causal_mask =
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, BLOCK_SIZE, BLOCK_SIZE), 1.0f),
+            GGML_TRI_TYPE_LOWER);
+    ggml_tensor * blocked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, BLOCK_SIZE), 1.0f));
+    ggml_tensor * blocked_diag_mask = ggml_add(ctx0, blocked_causal_mask, blocked_identity);
 
     ggml_build_forward_expand(gf, chunked_causal_mask);
     ggml_build_forward_expand(gf, chunked_identity);
-    ggml_build_forward_expand(gf, chunked_diag_mask);
+    ggml_build_forward_expand(gf, blocked_diag_mask);
 
     // Kimi dimension constants
     const int64_t n_head = hparams.n_head();
@@ -175,7 +179,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
         // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
         std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
             build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
-            build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
+            build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, blocked_diag_mask, il);
 
         ggml_tensor * output = attn_out.first;
         ggml_tensor * new_state = attn_out.second;
@@ -404,6 +408,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
         ggml_tensor * identity,
         ggml_tensor * diag_mask,
         int il) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(gk));
+    GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S_k = q->ne[0];
@@ -519,45 +528,99 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
            g_j = g[:, :, i, j:j+1, :]
            A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
     */
+    // compute matrix multiplication block by block and
+    // skip the blocks of zeros in the upper triangle
+    // Initialize Akk and Aqk to zeros
     const int64_t CHB = n_chunks * H_k * n_seqs;
-    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
-    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB]
+    ggml_tensor * Akk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
+    Akk = ggml_clamp(ctx0, Akk, 0.0f, 0.0f);
+    ggml_tensor * Aqk = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, chunk_size, chunk_size, n_chunks, HB);
+    Aqk = ggml_clamp(ctx0, Aqk, 0.0f, 0.0f);
+    const int64_t block_size = BLOCK_SIZE;
+    const int64_t n_blocks   = CHUNK_SIZE / BLOCK_SIZE;
+
+    ggml_tensor * k_block[n_blocks];
+    ggml_tensor * q_block[n_blocks];
+    ggml_tensor * gk_block[n_blocks];
+    ggml_tensor * gk_block_bc[n_blocks];
+    for (int64_t j = 0; j < n_blocks; ++j) {
+        int64_t j_start = j * block_size;
+
+        // k_i_block: [S, block_size, C, HB]
+        k_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, k,
+                S_k, block_size, n_chunks, HB,
+                k->nb[1], k->nb[2], k->nb[3],
+                j_start * k->nb[1]));
+
+        k_block[j] = ggml_reshape_4d(ctx0, k_block[j], S_k, block_size, 1, CHB);
+        // q_i_block: [S, block_size, C, HB]
+        q_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, q,
+                S_k, block_size, n_chunks, HB,
+                q->nb[1], q->nb[2], q->nb[3],
+                j_start * q->nb[1]));
-    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
-    // decay_mask [chunk_size,chunk_size,S_k,CHB]
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
-    cb(decay_mask, "decay_mask", il);
+        q_block[j] = ggml_reshape_4d(ctx0, q_block[j], S_k, block_size, 1, CHB);
+        // gk_j_block: [S, block_size, C, HB]
+        gk_block[j] = ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum,
+                block_size, S_k, n_chunks, HB,
+                gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
+                j_start * gk_cumsum->nb[0]));
+        gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], 1, block_size, S_k, CHB);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    cb(decay_mask, "decay_masked", il);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+        gk_block_bc[j] = ggml_repeat_4d(ctx0, gk_block[j], block_size, block_size, S_k, CHB);
-    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
-    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+        gk_block[j] = ggml_reshape_4d(ctx0, gk_block[j], block_size, 1, S_k, CHB);
+    }
-    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
-    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
-    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+    for (int64_t j = 0; j < n_blocks; ++j) {
+        int64_t j_start = j * block_size;
-    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
-    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
+        ggml_tensor * k_j_block = ggml_reshape_4d(ctx0, k_block[j], S_k, 1, block_size, CHB);
+        for (int64_t i = 0; i <= j; ++i) {
+            int64_t i_start = i * block_size;
-    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
-    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
-    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
-    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
-    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
-    cb(Akk, "Akk", il);
-    cb(Aqk, "Aqk", il);
+            ggml_tensor * decay_mask = ggml_sub(ctx0, gk_block_bc[j], gk_block[i]);
+            cb(decay_mask, "decay_mask", il);
+            // Apply diag_mask only at diagonal blocks
+            if (i == j) {
+                decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+                decay_mask = ggml_exp(ctx0, decay_mask);
+                decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+            } else {
+                decay_mask = ggml_exp(ctx0, decay_mask);
+            }
+
+            // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+            decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, block_size, block_size, CHB);
+
+            ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_block[i]);
+            ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_block[i]);
+
+            ggml_tensor * Akk_block = ggml_mul_mat(ctx0, decay_k_i, k_j_block);
+            ggml_tensor * Aqk_block = ggml_mul_mat(ctx0, decay_q_i, k_j_block);
+
+            Akk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk_block, block_size, block_size, n_chunks, HB)));
+            Aqk_block = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk_block, block_size, block_size, n_chunks, HB)));
+
+            if (i == j) {
+                Aqk_block = ggml_mul(ctx0, Aqk_block, diag_mask);
+            }
+            Aqk_block = ggml_scale(ctx0, Aqk_block, scale); // scale q
+
+            // Accumulate into Akk at position [j_start:j_end, i_start:i_end]
+            Akk = ggml_acc(ctx0, Akk, Akk_block,
+                    Akk->nb[1], Akk->nb[2], Akk->nb[3],
+                    i_start * Akk->nb[0] + j_start * Akk->nb[1]);
+            Aqk = ggml_acc(ctx0, Aqk, Aqk_block,
+                    Aqk->nb[1], Aqk->nb[2], Aqk->nb[3],
+                    i_start * Aqk->nb[0] + j_start * Aqk->nb[1]);
+        }
+    }
 
     Akk = ggml_mul(ctx0, Akk, beta);
     Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
-    cb(Akk, "attn_pre_solve", il);
-
-    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
-    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
-    cb(Aqk, "Aqk_masked", il);
+    cb(Akk, "Akk", il);
+    cb(Aqk, "Aqk", il);
 
     // for i in range(1, chunk_size):
    //     row = attn[..., i, :i].clone()
@@ -690,8 +753,12 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoreg
         ggml_tensor * beta,
         ggml_tensor * state,
         int il) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(gk));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S_k = q->ne[0];
     const int64_t H_k = q->ne[1];
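
Note on the tiling pattern: the hunk above replaces the dense CHUNK_SIZE x CHUNK_SIZE decay-mask computation with a block-by-block build that visits only the lower-triangular blocks (i <= j) and applies the intra-block mask only on diagonal blocks, accumulating each block into Akk/Aqk via ggml_acc at offset i_start * nb[0] + j_start * nb[1]. The snippet below is a minimal standalone sketch of that traversal in plain C++ (no ggml); all names in it (score, reference, blocked, bi, bj) are illustrative and are not part of the patch.

// Minimal sketch, assuming only the tiling scheme from the diff: build a
// lower-triangular CHUNK_SIZE x CHUNK_SIZE matrix block by block, visiting
// only blocks with block_col <= block_row and masking inside diagonal blocks.
#include <cstdio>
#include <vector>

constexpr int CHUNK_SIZE = 64;
constexpr int BLOCK_SIZE = 16;
constexpr int N_BLOCKS   = CHUNK_SIZE / BLOCK_SIZE;

int main() {
    // toy score s(r, c); any function works, the point is the traversal order
    auto score = [](int r, int c) { return float(r * CHUNK_SIZE + c); };

    std::vector<float> reference(CHUNK_SIZE * CHUNK_SIZE, 0.0f);
    std::vector<float> blocked  (CHUNK_SIZE * CHUNK_SIZE, 0.0f);

    // dense reference: keep only the lower triangle, diagonal included
    // (matching the diag_mask = causal mask + identity convention)
    for (int r = 0; r < CHUNK_SIZE; ++r) {
        for (int c = 0; c <= r; ++c) {
            reference[r * CHUNK_SIZE + c] = score(r, c);
        }
    }

    // blocked version: upper-triangle blocks are never touched at all;
    // only the diagonal blocks need an intra-block mask
    for (int bj = 0; bj < N_BLOCKS; ++bj) {       // block row, "j" in the diff
        for (int bi = 0; bi <= bj; ++bi) {        // block col, "i" in the diff
            for (int r = 0; r < BLOCK_SIZE; ++r) {
                for (int c = 0; c < BLOCK_SIZE; ++c) {
                    if (bi == bj && c > r) {
                        continue;                 // intra-block mask, diagonal blocks only
                    }
                    const int R = bj * BLOCK_SIZE + r;
                    const int C = bi * BLOCK_SIZE + c;
                    blocked[R * CHUNK_SIZE + C] = score(R, C);
                }
            }
        }
    }

    const bool ok = (reference == blocked);
    printf("blocked traversal matches dense reference: %s\n", ok ? "yes" : "no");
    return ok ? 0 : 1;
}

With CHUNK_SIZE = 64 and BLOCK_SIZE = 16, the blocked traversal touches 10 of the 16 blocks, which is the saving the added loop nest is after; the ggml version differs only in that each block is produced by masked mul_mat/exp graph ops and written into Akk/Aqk with ggml_acc instead of direct array stores.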