From cfed14e31bd02c4c4dc971f9c900b9c2d39ca6fe Mon Sep 17 00:00:00 2001
From: Yee Man Chan
Date: Tue, 6 Jan 2026 11:23:53 +0800
Subject: [PATCH] naive chunking form implemented

---
 src/models/kimi-linear.cpp | 223 +++++++++++++++++++++++++++++++++++--
 1 file changed, 214 insertions(+), 9 deletions(-)

diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index a943dd1dce..3fb40471a1 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -265,7 +265,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
     // TODO: Currently only build_kda_recurrent is implemented
     ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
-        build_kda_recurrent(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il) :
+        build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il) :
         build_kda_recurrent(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il);
     cb(attn_out, "attn_out", il);
 
@@ -485,7 +485,7 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
     GGML_ASSERT(k->ne[2] == n_tokens);
     GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
     GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
 
     GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
@@ -504,8 +504,6 @@
 
     const float scale = 1.0f / sqrtf(S_v);
 
-    q = ggml_scale(ctx0, q, scale);
-
     beta = ggml_sigmoid(ctx0, beta);
 
     cb(q, "q_in", il);
@@ -514,8 +512,8 @@
     cb(beta, "beta_in", il);
     cb(gk, "gk_in", il);
 
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
     v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
     gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
 
@@ -530,20 +528,227 @@
     cb(beta, "beta_perm", il);
     cb(gk, "gk_perm", il);
     cb(state, "state_in", il);
-    cb(causal_diag_mask, "causal_diag_mask", il);
 
     GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
     GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
     GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
     GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
 
+    // pad the token dimension up to a multiple of the chunk size
+    const int64_t chunk_size = CHUNK_SIZE;
+
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    gk = ggml_pad(ctx0, gk, 0, pad, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(gk, "gk_pad", il);
+
     ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
     ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-    cb(k_beta, "k_beta", il);
     cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
 
-    return nullptr;
+    ggml_tensor * chunked_mask =
+        ggml_view_4d(ctx0, causal_mask, chunk_size,
+            chunk_size, causal_mask->ne[2], causal_mask->ne[3],
+            causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0);
+
+    ggml_tensor * chunked_diag_mask =
+        ggml_view_4d(ctx0, causal_diag_mask, chunk_size,
+            chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3],
+            causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0);
+
+    ggml_tensor * chunked_identity =
+        ggml_view_4d(ctx0, identity, chunk_size,
+            chunk_size, identity->ne[2], identity->ne[3],
+            identity->nb[1], identity->nb[2], identity->nb[3], 0);
+
+    const int64_t HB = H_k * n_seqs;
+
+    q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB);
+    k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, HB);
+    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
+    v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, HB);
+    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);
+
+    gk = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
+    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);
+
+    // move the token dimension to dim 0 so ggml_cumsum accumulates gk over the positions within each chunk
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
+    ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+    const int64_t CHB = n_chunks * H_v * n_seqs;
+
+    ggml_tensor * g_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);
+    ggml_tensor * g_j = ggml_reshape_4d(ctx0, gk_cumsum, 1, chunk_size, S_k, CHB);
+
+    ggml_tensor * g_j_bc = ggml_repeat_4d(ctx0, g_j, chunk_size, chunk_size, S_k, CHB);
+
+    ggml_tensor * decay_mask = ggml_sub(ctx0, g_j_bc, g_i);
+
+    cb(decay_mask, "decay_mask", il);
+
+    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    cb(decay_mask, "decay_mask_exp", il);
+
+    // k [S,BT,NT,H*B] k_per [BT,S,NT,H*B]
+    ggml_tensor * k_per = ggml_cont(ctx0, ggml_permute(ctx0, k, 1, 0, 2, 3));
+    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k_per, chunk_size, 1, S_k, CHB);
+    ggml_tensor * k_i_bc = ggml_repeat_4d(ctx0, k_i, chunk_size, chunk_size, S_k, CHB);
+    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k_per, 1, chunk_size, S_k, CHB);
+    ggml_tensor * k_j_bc = ggml_repeat_4d(ctx0, k_j, chunk_size, chunk_size, S_k, CHB);
+
+    ggml_tensor * Akk = ggml_mul(ctx0, decay_mask, k_j_bc);
+    Akk = ggml_mul(ctx0, Akk, k_i_bc);
+
+    Akk = ggml_cont(ctx0, ggml_permute(ctx0, Akk, 1, 2, 0, 3));
+    Akk = ggml_sum_rows(ctx0, Akk);
+
+    Akk = ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, H_k * n_seqs);
+
+    Akk = ggml_mul(ctx0, Akk, beta);
+    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, chunked_mask));
+
+    cb(Akk, "attn_pre_solve", il);
+
+    ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, chunked_mask);
+    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower);
+
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
+    Akk = ggml_mul(ctx0, lin_solve, chunked_mask);
+    Akk = ggml_add(ctx0, Akk, chunked_identity);
+
+    cb(Akk, "attn_solved", il);
+
+    ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
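+    // At this point Akk is the inverse of the unit-lower-triangular in-chunk system, so
+    // vb (analogous to the "u" term of the chunked delta rule) and k_cumdecay below are
+    // the beta-corrected per-chunk values and keys. The per-chunk loop that follows
+    // combines an inter-chunk term read from the running state with an intra-chunk
+    // attention term, then advances the state once per chunk.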
+
+    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
+    ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum);
+
+    ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
+    cb(kbeta_gkexp, "kbeta_gkexp", il);
+
+    ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    ggml_tensor * core_attn_out = nullptr;
+    ggml_tensor * new_state = ggml_dup(ctx0, state);
+
+    cb(new_state, "new_state", il);
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+//  for (int64_t chunk = 0; chunk < 1; chunk++) {
+        // extract one chunk's worth of data
+        auto chunkify = [=](ggml_tensor * t) {
+            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+
+        // k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
+        ggml_tensor * k_chunk = chunkify(k);
+        ggml_tensor * q_chunk = chunkify(q);
+        ggml_tensor * vb_chunk = chunkify(vb);
+
+        // Since decay_mask now has shape [BT,BT,S,NT*H*B], it can't be chunkified;
+        // decay_mask_chunk needs to be recomputed
+        // gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
+        ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
+        ggml_tensor * gk_cs_chunk_i = ggml_cont(ctx0, ggml_permute(ctx0, gk_cs_chunk, 2, 0, 1, 3));
+        ggml_tensor * gk_cs_chunk_j = ggml_cont(ctx0, ggml_permute(ctx0, gk_cs_chunk, 2, 1, 0, 3));
+
+        ggml_tensor * gk_cs_chunk_j_bc = ggml_repeat_4d(ctx0, gk_cs_chunk_j, chunk_size, chunk_size, S_k, HB);
+        ggml_tensor * decay_mask_chunk = ggml_sub(ctx0, gk_cs_chunk_j_bc, gk_cs_chunk_i);
+        cb(decay_mask_chunk, "decay_mask_chunk", il);
+        decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, chunked_diag_mask);
+        decay_mask_chunk = ggml_exp(ctx0, decay_mask_chunk);
+        decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, chunked_diag_mask);
+        cb(decay_mask_chunk, "decay_mask_chunk_exp", il);
+
+        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+
+        ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
+
+        ggml_tensor * k_chunk_i = ggml_cont(ctx0, ggml_permute(ctx0, k_chunk, 2, 0, 1, 3));
+        ggml_tensor * k_chunk_i_bc = ggml_repeat_4d(ctx0, k_chunk_i, chunk_size, chunk_size, S_k, HB);
+        ggml_tensor * q_chunk_j = ggml_cont(ctx0, ggml_permute(ctx0, q_chunk, 2, 1, 0, 3));
+        ggml_tensor * q_chunk_j_bc = ggml_repeat_4d(ctx0, q_chunk_j, chunk_size, chunk_size, S_k, HB);
+        ggml_tensor * kq = ggml_mul(ctx0, decay_mask_chunk, q_chunk_j_bc);
+        kq = ggml_mul(ctx0, kq, k_chunk_i_bc);
+
+        ggml_tensor * Aqk = ggml_mul(ctx0, kq, decay_mask_chunk);
+        Aqk = ggml_mul(ctx0, Aqk, ggml_add(ctx0, chunked_identity, chunked_mask));
+        Aqk = ggml_cont(ctx0, ggml_permute(ctx0, Aqk, 1, 2, 0, 3));
+        Aqk = ggml_sum_rows(ctx0, Aqk);
+        Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
+        Aqk = ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, 1, HB);
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
+
+        // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+
+        // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
+        ggml_tensor * q_gk_exp = ggml_mul(ctx0, q_chunk, gkexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
+        attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
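+        // attn_inter is the contribution of tokens in earlier chunks, read from the
+        // running state; v_attn below adds the within-chunk contribution. The state
+        // update at the end of the loop body corresponds, up to tensor layout, to
+        // S <- S * exp(g_last) + (k * exp(g_last - g))^T v_new.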
+
+        // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk);
+
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+
+        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+
+        ggml_tensor * gk_cum_last =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
+                gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
+                gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));
+
+        ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
+
+        ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));
+
+        ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
+
+        ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
+
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
+
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
+    cb(output_tokens, "output_tokens", il);
+
+    // flatten output
+    ggml_tensor * flat_output =
+        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
+
+    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
+    cb(new_state, "output_state", il);
+
+    return ggml_concat(ctx0, flat_output, flat_state, 0);
 }
 
 ggml_tensor * llm_build_kimi_linear::build_kda_recurrent(