Kimi Linear backend agnostic

This commit is contained in:
Yee Man Chan 2026-01-05 16:35:19 +08:00
parent a4020d867f
commit 66c0c5d8d4
2 changed files with 450 additions and 64 deletions

View File

@ -1,24 +1,35 @@
#include "models.h"
#include "ggml.h"
#include "llama-impl.h"
#define CHUNK_SIZE 64
llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params), model(model) {
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
cb(inpL, "model.embed_tokens", -1);
// Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
// So we don't need inp_pos
// Only use recurrent state input for KDA layers
// MLA layers use direct softmax attention without KV cache
auto * inp_rs = build_rs_inp();
// Input for MLA layers (no KV cache)
auto * inp_no_cache = build_attn_inp_no_cache();
auto * inp = build_inp_mem_hybrid();
auto * inp_rs = inp->get_recr();
auto * inp_attn = inp->get_attn();
// Output ids for selecting which tokens to output
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
// Kimi dimension constants
const int64_t n_head = hparams.n_head();
const int64_t head_dim = hparams.kda_head_dim;
@ -40,10 +51,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
const int64_t n_embd_head_qk_rope = hparams.n_rot; // config.qk_rope_head_dim
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128
// Attention scale for KDA (1/sqrt(head_dim))
const float kq_scale_kda = 1.0f / sqrtf((float)head_dim);
// Attention scale for MLA
const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
@ -51,6 +58,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
const auto & layer = model.layers[il];
ggml_tensor * inpSA = inpL;
if (!layer.attn_norm)
LLAMA_LOG_INFO("Empty attn_norm at layer %d\n", il);
// Attention Norm
cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
@ -69,6 +78,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Get conv states from r_l tensor (Q, K, V each have separate state)
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
cb(conv_states_all, "conv_states_all", il);
const int64_t conv_state_size = (d_conv - 1) * d_inner;
const int64_t n_embd_r_total = 3 * conv_state_size; // Q + K + V
ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
@ -143,12 +153,14 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
Qcur = ggml_ssm_conv(ctx0, conv_q, conv_weight);
cb(Qcur, "Q conv1d", il);
// Reshape to 2D for bias add: {d_inner, n_tokens}
Qcur = ggml_reshape_2d(ctx0, Qcur, d_inner, n_tokens);
if (layer.ssm_q_conv_b) {
Qcur = ggml_add(ctx0, Qcur, layer.ssm_q_conv_b);
}
Qcur = ggml_silu(ctx0, Qcur);
cb(Qcur, "Q conv1d b", il);
} else {
GGML_ABORT("KDA layer missing Q conv weight");
}
@ -173,11 +185,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
}
ggml_tensor * k_conv_weight = ggml_reshape_2d(ctx0, k_conv_f32, d_conv, d_inner);
Kcur = ggml_ssm_conv(ctx0, conv_k, k_conv_weight);
cb(Kcur, "K conv1d", il);
Kcur = ggml_reshape_2d(ctx0, Kcur, d_inner, n_tokens);
if (layer.ssm_k_conv_b) {
Kcur = ggml_add(ctx0, Kcur, layer.ssm_k_conv_b);
}
Kcur = ggml_silu(ctx0, Kcur);
cb(Kcur, "K conv1d b", il);
} else {
GGML_ABORT("KDA layer missing K conv weight");
}
@ -202,11 +216,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
}
ggml_tensor * v_conv_weight = ggml_reshape_2d(ctx0, v_conv_f32, d_conv, d_inner);
Vcur = ggml_ssm_conv(ctx0, conv_v, v_conv_weight);
cb(Vcur, "V conv1d", il);
Vcur = ggml_reshape_2d(ctx0, Vcur, d_inner, n_tokens);
if (layer.ssm_v_conv_b) {
Vcur = ggml_add(ctx0, Vcur, layer.ssm_v_conv_b);
}
Vcur = ggml_silu(ctx0, Vcur);
cb(Vcur, "V conv1d b", il);
} else {
GGML_ABORT("KDA layer missing V conv weight");
}
@ -215,6 +231,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
cb(g1, "g1 f_b(f_a(cur))", il);
g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
g1 = ggml_softplus(ctx0, g1);
g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
@ -229,7 +246,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Step 4: Compute beta (mixing coefficient)
ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
beta = ggml_sigmoid(ctx0, beta);
beta = ggml_cont_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
cb(beta, "kda_beta", il);
// Step 5: Reshape for KDA recurrence
@ -240,49 +257,56 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
Kcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Kcur, head_dim, n_head, n_seq_tokens, n_seqs));
Vcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Vcur, head_dim, n_head, n_seq_tokens, n_seqs));
g1 = ggml_cont(ctx0, ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs));
beta = ggml_cont(ctx0, ggml_reshape_3d(ctx0, beta, n_head, n_seq_tokens, n_seqs));
cb(Qcur, "kda_Q", il);
cb(Kcur, "kda_K", il);
cb(Vcur, "kda_V", il);
// Step 6: Get SSM state and compute KDA recurrence using ggml_kda_scan
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
// Use build_rs with lambda pattern (like Mamba SSM scan)
auto get_kda_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * h_state = ggml_reshape_4d(ctx, states, head_dim, head_dim, n_head, mctx_cur->get_size());
// Call ggml_kda_scan which implements the correct KDA recurrence
return ggml_kda_scan(ctx, h_state, Qcur, Kcur, Vcur, g1, beta, ids);
};
ggml_tensor * y_kda = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs, get_kda_rows);
cb(y_kda, "kda_scan_out", il);
// Store updated state back
// y_kda contains: [attention_output (head_dim * n_head * n_seq_tokens * n_seqs), new_state (head_dim * head_dim * n_head * n_seqs)]
const int64_t attn_out_size = head_dim * n_head * n_seq_tokens * n_seqs;
const int64_t state_size = head_dim * head_dim * n_head;
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_kda, state_size * n_seqs, attn_out_size * ggml_element_size(y_kda)),
ggml_view_1d(ctx0, ssm_states_all, state_size * n_seqs, kv_head * state_size * ggml_element_size(ssm_states_all))));
// Extract attention output
ggml_tensor * attn_out = ggml_view_1d(ctx0, y_kda, attn_out_size, 0);
attn_out = ggml_reshape_3d(ctx0, attn_out, head_dim, n_head, n_seq_tokens * n_seqs);
cb(attn_out, "kda_attn_out", il);
ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
// Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
// TODO: Currently only build_kda_recurrent is implemented
ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
build_kda_recurrent(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il) :
build_kda_recurrent(Qcur, Kcur, Vcur, g1, beta, state, causal_mask, identity, il);
cb(attn_out, "attn_out", il);
// The tensors were concatenated 1d, so we need to extract them 1d as well
const int64_t output_flat_size = head_dim * n_head * n_seq_tokens * n_seqs;
ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
cb(attn_out_1d, "attn_out_1d", il);
ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, attn_out_1d, head_dim, n_head, n_seq_tokens * n_seqs);
cb(attn_out_final, "attn_out_reshaped", il);
// Extract the state part (second part of the concatenated tensor)
// State starts after n_tokens elements along dimension 1
const int64_t state_flat_size = head_dim * head_dim * n_head * n_seqs;
ggml_tensor * state_1d =
ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
cb(state_1d, "state_1d", il);
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, state_1d,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
// Step 7: Output gating g2 = g_b(g_a(x))
ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
cb(g2, "g2 g_b(g_a(cur_2d))", il);
g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
// Step 8: Apply o_norm with sigmoid gating
// Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
// Formula: output = RMSNorm(x) * sigmoid(g)
ggml_tensor * normed = build_norm(attn_out, layer.ssm_o_norm, layer.ssm_o_norm_b, LLM_NORM_RMS, il);
ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, layer.ssm_o_norm_b, LLM_NORM_RMS, il);
cb(normed, "kda_normed", il);
ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
@ -290,11 +314,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
cur = ggml_mul_mat(ctx0, layer.wo, gated);
cb(cur, "kda_out", il);
GGML_UNUSED(d_conv);
GGML_UNUSED(kq_scale_kda);
} else if (is_mla) {
// === MLA Layer (Multi-head Latent Attention) without KV Cache ===
// Reference: vLLM mla.py
@ -308,25 +328,25 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
cb(Qcur, "mla_Q", il);
// Step 2: KV compression
// kv_lora = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
ggml_tensor * kv_lora = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
// kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
// Split: kv_c = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
ggml_tensor * kv_c = ggml_view_2d(ctx0, kv_lora, kv_lora_rank, n_tokens,
ggml_row_size(kv_lora->type, kv_lora_rank + n_embd_head_qk_rope), 0);
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_lora, n_embd_head_qk_rope, 1, n_tokens,
ggml_row_size(kv_lora->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_lora->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_lora->type, kv_lora_rank));
// Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
// Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
// k_pe is used directly without RoPE
// Normalize kv_c
kv_c = build_norm(kv_c, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
// KV decompression: kv = kv_b_proj(kv_c_normed)
ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_c);
ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
// Split kv into k_nope and v
@ -344,17 +364,16 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
// Concatenate k_nope + k_pe (broadcast k_pe to all heads)
// K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
// and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
k_pe = ggml_cont(ctx0, k_pe);
// Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
cb(Kcur, "mla_K", il);
// Direct softmax attention (without KV cache)
// Use build_attn with inp_no_cache for proper mask handling
cur = build_attn(inp_no_cache, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
// cb(cur, "mla_out", il);
// Direct softmax attention (with KV cache)
// Use build_attn with inp_attn for proper mask handling
cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
cb(cur, "mla_out", il);
} else {
// Unknown layer type - this should not happen
@ -435,6 +454,352 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
GGML_UNUSED(n_embd_head_qk_nope);
}
/*
IMPORTANT: Currently build_kda_chunking is not implemented nor called
*/
ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * gk,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(gk));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
// TODO: can this ever be false?
const bool use_qk_l2norm = true;
if (use_qk_l2norm) {
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
}
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(gk, "gk_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(gk, "gk_perm", il);
cb(state, "state_in", il);
cb(causal_diag_mask, "causal_diag_mask", il);
GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
cb(k_beta, "k_beta", il);
cb(v_beta, "v_beta", il);
return nullptr;
}
ggml_tensor * llm_build_kimi_linear::build_kda_recurrent(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * gk,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(gk));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
// TODO: can this ever be false?
const bool use_qk_l2norm = true;
if (use_qk_l2norm) {
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
}
const float scale = 1.0f / sqrtf(S_v);
beta = ggml_sigmoid(ctx0, beta);
ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(gk, "gk_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 2, 0, 3), n_tokens, S_k, H_k, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(gk, "gk_perm", il);
cb(state, "state_in", il);
GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
// =========================================================================
// Compute cumulative sum of gk per key dimension
// gk_cumsum: [S_k, n_tokens, H_k, n_seqs] - cumsum along dim 1 (tokens)
// =========================================================================
ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
cb(gk_cumsum, "gk_cumsum", il);
// Scale k and k_beta
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
cb(k_beta, "k_beta", il);
cb(v_beta, "v_beta", il);
/*
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
for i in range(T):
k_i = k[..., i, :]
g_i = g[..., i:i+1, :]
A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
*/
const int64_t HB = H_k * n_seqs;
ggml_tensor * k_per = ggml_cont(ctx0, ggml_permute(ctx0, k, 1, 0, 2, 3));
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k_per, n_tokens, 1, S_k, HB);
ggml_tensor * k_i_bc = ggml_repeat_4d(ctx0, k_i, n_tokens, n_tokens, S_k, HB);
ggml_tensor * g_i = ggml_reshape_4d(ctx0, gk_cumsum, n_tokens, 1, S_k, HB);
ggml_tensor * g_i_bc = ggml_repeat_4d(ctx0, g_i, n_tokens, n_tokens, S_k, HB); // [S_k, chunk_size, 1, HB] -> [S_k, chunk_size, chunk_size, HB]
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k_per, 1, n_tokens, S_k, HB);
ggml_tensor * k_j_bc = ggml_repeat_4d(ctx0, k_j, n_tokens, n_tokens, S_k, HB);
ggml_tensor * g_j = ggml_reshape_4d(ctx0, gk_cumsum, 1, n_tokens, S_k, HB);
ggml_tensor * g_j_bc = ggml_repeat_4d(ctx0, g_j, n_tokens, n_tokens, S_k, HB); // [S_k, 1, chunk_size, HB] -> [S_k, chunk_size, chunk_size, HB]
ggml_tensor * decay_mask = ggml_sub(ctx0, g_j_bc, g_i_bc);
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
cb(decay_mask, "decay_mask_exp", il);
ggml_tensor * Akk = ggml_mul(ctx0, decay_mask, k_j_bc);
Akk = ggml_mul(ctx0, Akk, k_i_bc);
Akk = ggml_cont(ctx0, ggml_permute(ctx0, Akk, 1, 2, 0, 3));
Akk = ggml_sum_rows(ctx0, Akk);
Akk = ggml_reshape_4d(ctx0, Akk, n_tokens, n_tokens, H_k, n_seqs);
Akk = ggml_mul(ctx0, Akk, beta);
Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
cb(Akk, "attn_pre_rec", il);
// for i in range(1, chunk_size):
// row = attn[..., i, :i].clone()
// sub = attn[..., :i, :i].clone()
// attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
// attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
//
// We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
Akk = ggml_mul(ctx0, lin_solve, causal_mask);
Akk = ggml_add(ctx0, Akk, identity);
gk_cumsum = ggml_cont(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3)); // back to [S_k, n_tokens, H_k, n_seqs]
// u = (A*beta[..., None, :]) @ v aka U_[t]
ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
cb(vb, "value_beta", il);
// k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) or W_[t]
ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum); // [S,T,H,B]
ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
cb(kbeta_gkexp, "kbeta_gkexp", il);
ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
cb(k_cumdecay, "k_cumdecay", il);
/*
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
for j in range(BT):
k_j = k[:, :, i, j]
g_j = g[:, :, i, j:j+1, :]
A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
*/
ggml_tensor * q_per = ggml_cont(ctx0, ggml_permute(ctx0, q, 1, 0, 2, 3));
ggml_tensor * q_j = ggml_reshape_4d(ctx0, q_per, 1, n_tokens, S_k, HB);
ggml_tensor * q_j_bc = ggml_repeat_4d(ctx0, q_j, n_tokens, n_tokens, S_k, HB);
ggml_tensor * kq = ggml_mul(ctx0, decay_mask, q_j_bc);
kq = ggml_mul(ctx0, kq, k_i_bc);
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 1, 2, 0, 3));
ggml_tensor * Aqk = ggml_sum_rows(ctx0, kq);
Aqk = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Aqk, n_tokens, n_tokens, H_k, n_seqs));
Aqk = ggml_mul(ctx0, Aqk, ggml_add(ctx0, identity, causal_mask));
Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
cb(Aqk, "attn_decay_key", il);
ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay);
cb(v_prime, "v_prime", il);
// v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb, v_prime), v_prime);
// v_new_t [T.S.H,B]
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new", il);
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
// or Gamma_[t]*Q_]t] @ S
ggml_tensor * q_gk_exp = ggml_mul(ctx0, q, gkexp);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
// scale q at attn_inter as suggested in chunk_gla_fwd_kernel_o of
// github.com/fla-org/flash-linear-attention/fla/ops/gla/chunk.py
attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
cb(attn_inter, "attn_inter", il);
// core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk);
cb(v_attn, "v_attn", il);
// o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out, "core_attn_out", il);
ggml_tensor * gk_cum_last =
ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cumsum, gk_cumsum->ne[0], 1, gk_cumsum->ne[2], gk_cumsum->ne[3],
gk_cumsum->nb[1], gk_cumsum->nb[2], gk_cumsum->nb[3],
gk_cumsum->nb[1] * (gk_cumsum->ne[1] - 1)));
cb(gk_cum_last, "gk_cum_last", il);
ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
cb(gkexp_last, "gkexp_last", il);
ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cumsum, gk_cum_last));
cb(gk_diff, "gk_diff", il);
ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
cb(gk_diff_exp, "gk_diff_exp", il);
ggml_tensor * key_gkdiff = ggml_mul(ctx0, k, gk_diff_exp);
cb(key_gkdiff, "key_gkdiff", il);
// rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
ggml_tensor * kgkdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
cb(kgkdmulvnew, "kgkdmulvnew", il);
state = ggml_add(ctx0, ggml_mul(ctx0, state, gkexp_last), kgkdmulvnew);
cb(state, "new_state", il);
// flatten output
ggml_tensor * flat_output =
ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
return ggml_concat(ctx0, flat_output, flat_state, 0);
}

View File

@ -287,6 +287,27 @@ struct llm_build_kimi_linear : public llm_graph_context_mamba {
llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
private:
const llama_model & model;
ggml_tensor * build_kda_recurrent(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_kda_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
};
struct llm_build_lfm2 : public llm_graph_context {