build_kda_autoregressive is implemented to replace build_kda_recurrent for faster inference. sync'd to b7682

This commit is contained in:
Yee Man Chan 2026-01-07 18:42:31 +08:00
parent 40f6118192
commit 1099cbf694
2 changed files with 110 additions and 254 deletions

View File

@ -20,14 +20,16 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Output ids for selecting which tokens to output
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
ggml_tensor * chunked_causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, chunked_causal_mask);
ggml_build_forward_expand(gf, chunked_identity);
ggml_build_forward_expand(gf, chunked_diag_mask);
// Kimi dimension constants
const int64_t n_head = hparams.n_head();
@ -263,9 +265,9 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
// Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il) :
build_kda_recurrent(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il);
ggml_tensor * attn_out = n_seq_tokens == 1 ?
build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
cb(attn_out, "attn_out", il);
// The tensors were concatenated 1d, so we need to extract them 1d as well
@ -464,6 +466,7 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
ggml_tensor * diag_mask,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
@ -519,8 +522,6 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
@ -557,21 +558,6 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);
ggml_tensor * chunked_mask =
ggml_view_4d(ctx0, causal_mask, chunk_size,
chunk_size, causal_mask->ne[2], causal_mask->ne[3],
causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0);
ggml_tensor * chunked_diag_mask =
ggml_view_4d(ctx0, causal_diag_mask, chunk_size,
chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3],
causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0);
ggml_tensor * chunked_identity =
ggml_view_4d(ctx0, identity, chunk_size,
chunk_size, identity->ne[2], identity->ne[3],
identity->nb[1], identity->nb[2], identity->nb[3], 0);
const int64_t HB = H_k * n_seqs;
q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB);
@ -588,6 +574,14 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
cb(gk_cumsum, "gk_cumsum", il);
/*
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
for i in range(T):
k_i = k[..., i, :]
g_i = g[..., i:i+1, :]
A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
*/
const int64_t CHB = n_chunks * H_v * n_seqs;
ggml_tensor * g_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);
@ -599,9 +593,9 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
cb(decay_mask, "decay_mask_exp", il);
// k [S,BT,NT,H*B] k_per [BT,S,NT,H*B]
@ -620,19 +614,27 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
Akk = ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, H_k * n_seqs);
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, chunked_mask));
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
cb(Akk, "attn_pre_solve", il);
ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, chunked_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower);
// for i in range(1, chunk_size):
// row = attn[..., i, :i].clone()
// sub = attn[..., :i, :i].clone()
// attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
// attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
//
// We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
Akk = ggml_mul(ctx0, lin_solve, chunked_mask);
Akk = ggml_add(ctx0, Akk, chunked_identity);
Akk = ggml_mul(ctx0, lin_solve, causal_mask);
Akk = ggml_add(ctx0, Akk, identity);
cb(Akk, "attn_solved", il);
// u = (A*beta[..., None, :]) @ v aka U_[t]
ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
@ -650,7 +652,6 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
cb(new_state, "new_state", il);
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
// for (int64_t chunk = 0; chunk < 1; chunk++) {
// extract one chunk worth of data
auto chunkify = [=](ggml_tensor * t) {
return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
@ -672,15 +673,22 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
ggml_tensor * gk_cs_chunk_j_bc = ggml_repeat_4d(ctx0, gk_cs_chunk_j, chunk_size, chunk_size, S_k, HB);
ggml_tensor * decay_mask_chunk = ggml_sub(ctx0, gk_cs_chunk_j_bc, gk_cs_chunk_i);
cb(decay_mask_chunk, "decay_mask_chunk", il);
decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, chunked_diag_mask);
decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, diag_mask);
decay_mask_chunk = ggml_exp(ctx0, decay_mask_chunk);
decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, chunked_diag_mask);
decay_mask_chunk = ggml_mul(ctx0, decay_mask_chunk, diag_mask);
cb(decay_mask_chunk, "decay_mask_chunk_exp", il);
ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
/*
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
for j in range(BT):
k_j = k[:, :, i, j]
g_j = g[:, :, i, j:j+1, :]
A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
*/
ggml_tensor * k_chunk_i = ggml_cont(ctx0, ggml_permute(ctx0, k_chunk, 2, 0, 1, 3));
ggml_tensor * k_chunk_i_bc = ggml_repeat_4d(ctx0, k_chunk_i, chunk_size, chunk_size, S_k, HB);
ggml_tensor * q_chunk_j = ggml_cont(ctx0, ggml_permute(ctx0, q_chunk, 2, 1, 0, 3));
@ -689,7 +697,7 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
kq = ggml_mul(ctx0, kq, k_chunk_i_bc);
ggml_tensor * Aqk = ggml_mul(ctx0, kq, decay_mask_chunk);
Aqk = ggml_mul(ctx0, Aqk, ggml_add(ctx0, chunked_identity, chunked_mask));
Aqk = ggml_mul(ctx0, Aqk, ggml_add(ctx0, identity, causal_mask));
Aqk = ggml_cont(ctx0, ggml_permute(ctx0, Aqk, 1, 2, 0, 3));
Aqk = ggml_sum_rows(ctx0, Aqk);
Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
@ -697,20 +705,26 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
// new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
// new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
// v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
// q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
// q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
// or Gamma_[t]*Q_]t] @ S
ggml_tensor * q_gk_exp = ggml_mul(ctx0, q_chunk, gkexp_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
// v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
// v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
// core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk);
// o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
@ -728,6 +742,7 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
// rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
new_state = ggml_add(ctx0,
@ -750,256 +765,98 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
return ggml_concat(ctx0, flat_output, flat_state, 0);
}
ggml_tensor * llm_build_kimi_linear::build_kda_recurrent(
ggml_tensor * llm_build_kimi_linear::build_kda_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * gk,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
ggml_tensor * state,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(gk));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_tokens == 1);
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
// TODO: can this ever be false?
const bool use_qk_l2norm = true;
if (use_qk_l2norm) {
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
}
const float scale = 1.0f / sqrtf(S_v);
beta = ggml_sigmoid(ctx0, beta);
ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(gk, "gk_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 2, 0, 3), n_tokens, S_k, H_k, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(gk, "gk_perm", il);
cb(state, "state_in", il);
GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
// =========================================================================
// Compute cumulative sum of gk per key dimension
// gk_cumsum: [S_k, n_tokens, H_k, n_seqs] - cumsum along dim 1 (tokens)
// =========================================================================
ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
cb(gk_cumsum, "gk_cumsum", il);
// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
// Scale k and k_beta
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
cb(k_beta, "k_beta", il);
cb(v_beta, "v_beta", il);
// Apply exponential to gk_t
gk_t = ggml_exp(ctx0, gk_t);
// Apply the gated delta rule for the single timestep
// last_recurrent_state = last_recurrent_state * gk_t
// S = S * g_i[..., None].exp()
state = ggml_mul(ctx0, state, gk_t);
/*
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
for i in range(T):
k_i = k[..., i, :]
g_i = g[..., i:i+1, :]
A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
*/
const int64_t HB = H_k * n_seqs;
ggml_tensor * k_per = ggml_cont(ctx0, ggml_permute(ctx0, k, 1, 0, 2, 3));
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k_per, n_tokens, 1, S_k, HB);
ggml_tensor * k_i_bc = ggml_repeat_4d(ctx0, k_i, n_tokens, n_tokens, S_k, HB);
ggml_tensor * g_i = ggml_reshape_4d(ctx0, gk_cumsum, n_tokens, 1, S_k, HB);
ggml_tensor * g_i_bc = ggml_repeat_4d(ctx0, g_i, n_tokens, n_tokens, S_k, HB); // [S_k, chunk_size, 1, HB] -> [S_k, chunk_size, chunk_size, HB]
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k_per, 1, n_tokens, S_k, HB);
ggml_tensor * k_j_bc = ggml_repeat_4d(ctx0, k_j, n_tokens, n_tokens, S_k, HB);
ggml_tensor * g_j = ggml_reshape_4d(ctx0, gk_cumsum, 1, n_tokens, S_k, HB);
ggml_tensor * g_j_bc = ggml_repeat_4d(ctx0, g_j, n_tokens, n_tokens, S_k, HB); // [S_k, 1, chunk_size, HB] -> [S_k, chunk_size, chunk_size, HB]
ggml_tensor * decay_mask = ggml_sub(ctx0, g_j_bc, g_i_bc);
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
cb(decay_mask, "decay_mask_exp", il);
ggml_tensor * Akk = ggml_mul(ctx0, decay_mask, k_j_bc);
Akk = ggml_mul(ctx0, Akk, k_i_bc);
Akk = ggml_cont(ctx0, ggml_permute(ctx0, Akk, 1, 2, 0, 3));
Akk = ggml_sum_rows(ctx0, Akk);
Akk = ggml_reshape_4d(ctx0, Akk, n_tokens, n_tokens, H_k, n_seqs);
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
cb(Akk, "attn_pre_rec", il);
// for i in range(1, chunk_size):
// row = attn[..., i, :i].clone()
// sub = attn[..., :i, :i].clone()
// attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
// attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
//
// We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
Akk = ggml_mul(ctx0, lin_solve, causal_mask);
Akk = ggml_add(ctx0, Akk, identity);
gk_cumsum = ggml_cont(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3)); // back to [S_k, n_tokens, H_k, n_seqs]
// u = (A*beta[..., None, :]) @ v aka U_[t]
ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
cb(vb, "value_beta", il);
// k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) or W_[t]
ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum); // [S,T,H,B]
ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
cb(kbeta_gkexp, "kbeta_gkexp", il);
ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
cb(k_cumdecay, "k_cumdecay", il);
/*
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
for j in range(BT):
k_j = k[:, :, i, j]
g_j = g[:, :, i, j:j+1, :]
A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
*/
ggml_tensor * q_per = ggml_cont(ctx0, ggml_permute(ctx0, q, 1, 0, 2, 3));
ggml_tensor * q_j = ggml_reshape_4d(ctx0, q_per, 1, n_tokens, S_k, HB);
ggml_tensor * q_j_bc = ggml_repeat_4d(ctx0, q_j, n_tokens, n_tokens, S_k, HB);
ggml_tensor * kq = ggml_mul(ctx0, decay_mask, q_j_bc);
kq = ggml_mul(ctx0, kq, k_i_bc);
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 1, 2, 0, 3));
ggml_tensor * Aqk = ggml_sum_rows(ctx0, kq);
Aqk = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Aqk, n_tokens, n_tokens, H_k, n_seqs));
Aqk = ggml_mul(ctx0, Aqk, ggml_add(ctx0, identity, causal_mask));
Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
cb(Aqk, "attn_decay_key", il);
ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay);
cb(v_prime, "v_prime", il);
// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);
// v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb, v_prime), v_prime);
// v_i - (k_i[..., None] * S).sum(-2)
v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);
// v_new_t [T.S.H,B]
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
// b_i[..., None] * k_i
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);
cb(v_new, "v_new", il);
// S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
// v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
// or Gamma_[t]*Q_]t] @ S
ggml_tensor * q_gk_exp = ggml_mul(ctx0, q, gkexp);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
// scale q at attn_inter as suggested in chunk_gla_fwd_kernel_o of
// github.com/fla-org/flash-linear-attention/fla/ops/gla/chunk.py
attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
cb(attn_inter, "attn_inter", il);
// core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk);
cb(v_attn, "v_attn", il);
// o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out, "core_attn_out", il);
ggml_tensor * gk_cum_last =
ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum, gk_cumsum->ne[0], 1, gk_cumsum->ne[2], gk_cumsum->ne[3],
gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
gk_cumsum->nb[1] * (gk_cumsum->ne[1] - 1)));
cb(gk_cum_last, "gk_cum_last", il);
ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
cb(gkexp_last, "gkexp_last", il);
ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cumsum, gk_cum_last));
cb(gk_diff, "gk_diff", il);
ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
cb(gk_diff_exp, "gk_diff_exp", il);
ggml_tensor * key_gkdiff = ggml_mul(ctx0, k, gk_diff_exp);
cb(key_gkdiff, "key_gkdiff", il);
// rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
ggml_tensor * kgkdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
cb(kgkdmulvnew, "kgkdmulvnew", il);
state = ggml_add(ctx0, ggml_mul(ctx0, state, gkexp_last), kgkdmulvnew);
q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
// flatten output
ggml_tensor * flat_output =
ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
// flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
return ggml_concat(ctx0, flat_output, flat_state, 0);
}

View File

@ -288,26 +288,25 @@ struct llm_build_kimi_linear : public llm_graph_context_mamba {
llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
private:
const llama_model & model;
ggml_tensor * build_kda_recurrent(
ggml_tensor * build_kda_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * gk,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_kda_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * gk,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
ggml_tensor * diag_mask,
int il);
};