return ggml_tensor * pair in kda_autoregressive and kda_chunking as in ngxson's Qwen3Next improvement
parent 4faf26c376
commit 22bc582a82
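The change follows the pattern from ngxson's Qwen3Next builder: instead of flattening the attention output and the updated recurrent state into a single concatenated 1-D tensor that the caller has to slice apart, build_kda_autoregressive and build_kda_chunking now return both tensors directly as a std::pair. A minimal sketch of the new call-site shape (it mirrors the first hunk below and is not a verbatim copy of the source):

    // The two KDA builders now hand back {attention output, updated recurrent state}.
    std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
        build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
        build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state,
                           chunked_causal_mask, chunked_identity, chunked_diag_mask, il);

    ggml_tensor * output    = attn_out.first;   // per-token attention output
    ggml_tensor * new_state = attn_out.second;  // state to write back into the recurrent cache

This removes the element-count bookkeeping (the ggml_view_1d offsets and the GGML_ASSERT on ggml_nelements) that the concatenated return value required.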
@@ -263,34 +263,21 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
     // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
-    ggml_tensor * attn_out = n_seq_tokens == 1 ?
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
         build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
         build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
-    cb(attn_out, "attn_out", il);
-
-    // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_head * n_seq_tokens * n_seqs;
-    ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
-    cb(attn_out_1d, "attn_out_1d", il);
+    ggml_tensor * output = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
-    ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, attn_out_1d, head_dim, n_head, n_seq_tokens * n_seqs);
-    cb(attn_out_final, "attn_out_reshaped", il);
-    // Extract the state part (second part of the concatenated tensor)
-    // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_dim * head_dim * n_head * n_seqs;
-
-    ggml_tensor * state_1d =
-        ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
-    cb(state_1d, "state_1d", il);
-
-    // Update the recurrent states
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0, state_1d,
+    // Update the recurrent states
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, new_state,
             ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
                 kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

-    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
-
     // Step 7: Output gating g2 = g_b(g_a(x))
     ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
     ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
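The write-back of the recurrent state is unchanged apart from its source tensor: previously the state slice state_1d was carved out of the concatenated result, now new_state is copied directly. A commented sketch of that copy, using the same calls as the hunk above (the state_dst local and the comments are illustrative, not part of the source):

    // Destination: this layer's state slots in the global recurrent-state pool.
    // hparams.n_embd_s() is the per-sequence state size; kv_head selects the first
    // cache slot used by the current ubatch, and the view offset is given in bytes.
    ggml_tensor * state_dst = ggml_view_1d(ctx0, ssm_states_all,
            hparams.n_embd_s() * n_seqs,
            kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all));

    // Schedule the copy so the updated state is materialized when the graph runs.
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, state_dst));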
@@ -301,6 +288,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Step 8: Apply o_norm with sigmoid gating
     // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
     // Formula: output = RMSNorm(x) * sigmoid(g)
+    ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head, n_seq_tokens * n_seqs);
     ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, layer.ssm_o_norm_b, LLM_NORM_RMS, il);
     cb(normed, "kda_normed", il);
     ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
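The gating product itself sits in unchanged lines outside this hunk; from the comments the intended combination is output = RMSNorm(x) * sigmoid(g). A sketch of what that final step presumably looks like (the gated variable and the "kda_gated" name are illustrative, not taken from the source):

    // Elementwise product of the RMS-normalized attention output and the sigmoid gate.
    ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
    cb(gated, "kda_gated", il);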
@@ -496,7 +484,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 This is a ggml implementation of the naive_chunk_kda function of
 https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
 */
-ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunking(
     ggml_tensor * q,
     ggml_tensor * k,
     ggml_tensor * v,
@@ -774,20 +762,23 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(

     core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);

-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
-    cb(output_tokens, "output_tokens", il);
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(core_attn_out->type, S_v),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);

-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
     cb(new_state, "output_state", il);

-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {output_tokens, new_state};
 }

-ggml_tensor * llm_build_kimi_linear::build_kda_autoregressive(
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoregressive(
     ggml_tensor * q,
     ggml_tensor * k,
     ggml_tensor * v,
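The new view computes its byte strides explicitly with ggml_row_size instead of reusing core_attn_out->nb[1..3]. After ggml_cont_4d the data is laid out as (S_v, chunk_size * n_chunks, H_v, n_seqs), so the head and sequence strides must still step over the full padded token count even though the view keeps only the first n_tokens rows. A commented restatement of that stride arithmetic (same calls as above, split out purely for readability):

    // ggml_row_size(type, n) returns the size in bytes of n elements of the given type.
    const size_t nb1 = ggml_row_size(core_attn_out->type, S_v);                               // step to the next token
    const size_t nb2 = ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks);       // step to the next head, skipping padded tokens
    const size_t nb3 = ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v); // step to the next sequence

    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
            S_v, n_tokens, H_v, n_seqs, nb1, nb2, nb3, 0);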
@@ -876,10 +867,6 @@ ggml_tensor * llm_build_kimi_linear::build_kda_autoregressive(
     cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);

-    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
-    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
-    ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {core_attn_out, state};
 }
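In the single-token path no permute is needed before returning: with n_tokens == 1 the layouts [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] enumerate the same contiguous elements in the same order, so the caller's later ggml_reshape_3d sees identical data either way. A sketch of that equivalence (assuming core_attn_out is contiguous; purely illustrative):

    // Both reshapes leave the underlying bytes untouched; only the shape metadata differs,
    // and for n_tokens == 1 the element order is identical in the two forms.
    ggml_tensor * as_tokens_minor = ggml_reshape_4d(ctx0, core_attn_out, S_v, 1, H_v, n_seqs);
    ggml_tensor * as_heads_minor  = ggml_reshape_4d(ctx0, core_attn_out, S_v, H_v, 1, n_seqs);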
@@ -287,7 +287,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
 struct llm_build_kimi_linear : public llm_graph_context_mamba {
     llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
 private:
-    ggml_tensor * build_kda_autoregressive(
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -296,7 +296,7 @@ private:
         ggml_tensor * state,
         int il);

-    ggml_tensor * build_kda_chunking(
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_chunking(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,