diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index 270f9e6e6b..b229d31165 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -571,48 +571,40 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
     // switch for cumsum
     gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
+    cb(gk, "gk", il);
 
     ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
     cb(gk_cumsum, "gk_cumsum", il);
 
+    // switch back for downstream
+    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
+    ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum);
+
+    cb(gk_cumsum, "gk_cumsum", il);
 
     /*
-    https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
-
-    for i in range(T):
-        k_i = k[..., i, :]
-        g_i = g[..., i:i+1, :]
+    for i in range(BT):
+        k_i = k[..., i, :]      # k_i [B,H,NT,S]
+        g_i = g[..., i:i+1, :]  # g_i [B,H,NT,1,S]
         A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
     */
 
-    const int64_t CHB = n_chunks * H_v * n_seqs;
+    // gk_ref: [S, 1, C, HB] - first token of i_block
+    ggml_tensor * gk_ref = ggml_view_4d(ctx0, gk_cumsum,
+            S_k, 1, n_chunks, HB,
+            gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
+            0);
+    cb(gk_ref, "gk_ref", il);
 
-    ggml_tensor * g_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);
-    ggml_tensor * g_j = ggml_reshape_4d(ctx0, gk_cumsum, 1, chunk_size, S_k, CHB);
-
-    ggml_tensor * g_j_bc = ggml_repeat_4d(ctx0, g_j, chunk_size, chunk_size, S_k, CHB);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, g_j_bc, g_i);
-
-    cb(decay_mask, "decay_mask", il);
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    cb(decay_mask, "decay_mask_exp", il);
-
-// k [S,BT,NT,H*B] k_per [BT,S,NT,H*B]
-    ggml_tensor * k_per = ggml_cont(ctx0, ggml_permute(ctx0, k, 1, 0, 2, 3));
-    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k_per, chunk_size, 1, S_k, CHB);
-    ggml_tensor * k_i_bc = ggml_repeat_4d(ctx0, k_i, chunk_size, chunk_size, S_k, CHB);
-    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k_per, 1, chunk_size, S_k, CHB);
-    ggml_tensor * k_j_bc = ggml_repeat_4d(ctx0, k_j, chunk_size, chunk_size, S_k, CHB);
-
-    ggml_tensor * Akk = ggml_mul(ctx0, decay_mask, k_j_bc);
-    Akk = ggml_mul(ctx0, Akk, k_i_bc);
-
-    Akk = ggml_cont(ctx0, ggml_permute(ctx0, Akk, 1, 2, 0, 3));
-    Akk = ggml_sum_rows(ctx0, Akk);
-
-    Akk = ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, H_k * n_seqs);
+    // Compute gk_diff
+    ggml_tensor * gk_diff_j = ggml_sub(ctx0, gk_cumsum, ggml_repeat(ctx0, gk_ref, gk_cumsum));
+    ggml_tensor * gk_diff_i = ggml_clamp(ctx0, ggml_neg(ctx0, gk_diff_j), 0.0f, 88.0f);
+    cb(gk_diff_j, "gk_diff_j", il);
+    cb(gk_diff_i, "gk_diff_i", il);
+    // Decay k
+    ggml_tensor * k_exp_j = ggml_mul(ctx0, k, ggml_exp(ctx0, gk_diff_j));
+    ggml_tensor * k_exp_i = ggml_mul(ctx0, k, ggml_exp(ctx0, gk_diff_i));
+    ggml_tensor * Akk = ggml_mul_mat(ctx0, k_exp_i, k_exp_j);
+    cb(Akk, "Akk", il);
+
     Akk = ggml_mul(ctx0, Akk, beta);
     Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
@@ -637,9 +629,6 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
     // u = (A*beta[..., None, :]) @ v aka U_[t]
     ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
 
-    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
-    ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum);
-
     ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
     cb(kbeta_gkexp, "kbeta_gkexp", il);
 
@@ -663,23 +652,9 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
         ggml_tensor * q_chunk = chunkify(q);
         ggml_tensor * vb_chunk = chunkify(vb);
 
-        // Since decay_mask now has dimension of [BT,BT,S,NT*H*B], it can't be chunkified
-        // decay_mask_chunk needs to be recomputed
         // gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
         ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
-        ggml_tensor * gk_cs_chunk_i = ggml_cont(ctx0, ggml_permute(ctx0, gk_cs_chunk, 2, 0, 1, 3));
-        ggml_tensor * gk_cs_chunk_j = ggml_cont(ctx0, ggml_permute(ctx0, gk_cs_chunk, 2, 1, 0, 3));
-
-        ggml_tensor * gk_cs_chunk_j_bc = ggml_repeat_4d(ctx0, gk_cs_chunk_j, chunk_size, chunk_size, S_k, HB);
-        ggml_tensor * decay_mask_chunk = ggml_sub(ctx0, gk_cs_chunk_j_bc, gk_cs_chunk_i);
-        cb(decay_mask_chunk, "decay_mask_chunk", il);
-        decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, diag_mask);
-        decay_mask_chunk = ggml_exp(ctx0, decay_mask_chunk);
-        decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, diag_mask);
-        cb(decay_mask_chunk, "decay_mask_chunk_exp", il);
-
         ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
-        ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
 
         /*
         https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
@@ -689,19 +664,25 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
             g_j = g[:, :, i, j:j+1, :]
            A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
         */
 
-        ggml_tensor * k_chunk_i = ggml_cont(ctx0, ggml_permute(ctx0, k_chunk, 2, 0, 1, 3));
-        ggml_tensor * k_chunk_i_bc = ggml_repeat_4d(ctx0, k_chunk_i, chunk_size, chunk_size, S_k, HB);
-        ggml_tensor * q_chunk_j = ggml_cont(ctx0, ggml_permute(ctx0, q_chunk, 2, 1, 0, 3));
-        ggml_tensor * q_chunk_j_bc = ggml_repeat_4d(ctx0, q_chunk_j, chunk_size, chunk_size, S_k, HB);
-        ggml_tensor * kq = ggml_mul(ctx0, decay_mask_chunk, q_chunk_j_bc);
-        kq = ggml_mul(ctx0, kq, k_chunk_i_bc);
+        ggml_tensor * gk_ref_chunk = ggml_view_4d(ctx0, gk_cs_chunk,
+                S_k, 1, 1, HB,
+                gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
+                0);
+        // Compute gk_diff
+        ggml_tensor * gk_diff_chunk_j = ggml_sub(ctx0, gk_cs_chunk, ggml_repeat(ctx0, gk_ref_chunk, gk_cs_chunk));
+        ggml_tensor * gk_diff_chunk_i = ggml_clamp(ctx0, ggml_neg(ctx0, gk_diff_chunk_j), 0.0f, 88.0f);
+        cb(gk_diff_chunk_j, "gk_diff_chunk_j", il);
+        cb(gk_diff_chunk_i, "gk_diff_chunk_i", il);
 
-        ggml_tensor * Aqk = ggml_mul(ctx0, kq, decay_mask_chunk);
-        Aqk = ggml_mul(ctx0, Aqk, ggml_add(ctx0, identity, causal_mask));
-        Aqk = ggml_cont(ctx0, ggml_permute(ctx0, Aqk, 1, 2, 0, 3));
-        Aqk = ggml_sum_rows(ctx0, Aqk);
+        // Decay q and k
+        ggml_tensor * q_exp_chunk = ggml_mul(ctx0, q_chunk, ggml_exp(ctx0, gk_diff_chunk_j));
+        ggml_tensor * k_exp_chunk = ggml_mul(ctx0, k_chunk, ggml_exp(ctx0, gk_diff_chunk_i));
+
+        ggml_tensor * Aqk = ggml_mul_mat(ctx0, k_exp_chunk, q_exp_chunk);
+        cb(Aqk, "Aqk", il);
+        Aqk = ggml_mul(ctx0, Aqk, diag_mask);
         Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
-        Aqk = ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, 1, HB);
+        cb(Aqk, "Aqk_masked", il);
 
         ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
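
Note on the change above (explanation only, not part of the patch): the removed path materialized the full [BT, BT, S] pairwise decay mask exp(g_j - g_i) and reduced it with broadcast repeats and ggml_sum_rows. The new path factors the decay against the first token of each chunk, exp(g_c - g_i) = exp(g_c - g_ref) * exp(g_ref - g_i), so Akk and Aqk each collapse to a single ggml_mul_mat over pre-decayed k/q. The clamp of the negated difference to [0, 88] only guards ggml_exp against f32 overflow (exp(88) ~ 1.65e38, just below FLT_MAX). The standalone sketch below (plain C++, toy sizes, no ggml, overflow clamp omitted) checks that the factorized form matches the naive pairwise form on the causal part; variable names mirror the patch, but the program itself is purely illustrative.

// standalone_check.cpp - illustration only, not code from this PR
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int BT = 4; // chunk size (tokens), toy value
    const int S  = 3; // key head dimension, toy value

    // toy inputs: k arbitrary, g = per-channel cumulative log-decay
    // (non-positive, decreasing along the token axis, like gk_cumsum)
    std::vector<std::vector<double>> k(BT, std::vector<double>(S));
    std::vector<std::vector<double>> g(BT, std::vector<double>(S));
    for (int t = 0; t < BT; ++t) {
        for (int d = 0; d < S; ++d) {
            k[t][d] = 0.1 * (t + 1) + 0.01 * d;
            g[t][d] = -(0.3 * t + 0.05 * d);
        }
    }

    double max_err = 0.0;
    for (int c = 0; c < BT; ++c) {
        for (int i = 0; i <= c; ++i) { // causal (lower-triangular) part only
            double a_naive = 0.0;
            double a_fact  = 0.0;
            for (int d = 0; d < S; ++d) {
                // naive pairwise decay, as in the removed decay_mask path
                a_naive += k[c][d] * std::exp(g[c][d] - g[i][d]) * k[i][d];

                // factorized against the first token of the chunk (g[0]),
                // mirroring gk_diff_j / gk_diff_i from the patch
                const double dj =   g[c][d] - g[0][d];  // gk_diff_j, <= 0
                const double di = -(g[i][d] - g[0][d]); // gk_diff_i, >= 0
                a_fact += (k[c][d] * std::exp(dj)) * (k[i][d] * std::exp(di));
            }
            max_err = std::max(max_err, std::fabs(a_naive - a_fact));
        }
    }
    std::printf("max |naive - factorized| over the causal part: %.3e\n", max_err);
    return 0;
}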