diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index b0330e23b3..4831b7bbc7 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -263,34 +263,21 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
         ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
         state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
 
         // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
-        ggml_tensor * attn_out = n_seq_tokens == 1 ?
+        std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
             build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
             build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
-        cb(attn_out, "attn_out", il);
 
-        // The tensors were concatenated 1d, so we need to extract them 1d as well
-        const int64_t output_flat_size = head_dim * n_head * n_seq_tokens * n_seqs;
-        ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
-        cb(attn_out_1d, "attn_out_1d", il);
+        ggml_tensor * output = attn_out.first;
+        ggml_tensor * new_state = attn_out.second;
+        cb(output, "attn_output", il);
+        cb(new_state, "new_state", il);
 
-        ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, attn_out_1d, head_dim, n_head, n_seq_tokens * n_seqs);
-        cb(attn_out_final, "attn_out_reshaped", il);
-        // Extract the state part (second part of the concatenated tensor)
-        // State starts after n_tokens elements along dimension 1
-        const int64_t state_flat_size = head_dim * head_dim * n_head * n_seqs;
-
-        ggml_tensor * state_1d =
-            ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
-        cb(state_1d, "state_1d", il);
-
-        // Update the recurrent states
-        ggml_build_forward_expand(gf,
-            ggml_cpy(ctx0, state_1d,
+        // Update the recurrent states
+        ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0, new_state,
                 ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
-        GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
-
         // Step 7: Output gating g2 = g_b(g_a(x))
         ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
         ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
@@ -301,6 +288,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
         // Step 8: Apply o_norm with sigmoid gating
         // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
         // Formula: output = RMSNorm(x) * sigmoid(g)
+        ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head, n_seq_tokens * n_seqs);
         ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, layer.ssm_o_norm_b, LLM_NORM_RMS, il);
         cb(normed, "kda_normed", il);
         ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
@@ -496,7 +484,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 This is a ggml implementation of the naive_chunk_kda function of
 https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
 */
-ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunking(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -774,20 +762,23 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking(
 
     core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
 
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
-    cb(output_tokens, "output_tokens", il);
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+        S_v, n_tokens, H_v, n_seqs,
+        ggml_row_size(core_attn_out->type, S_v),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
 
-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
     cb(new_state, "output_state", il);
 
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {output_tokens, new_state};
 }
 
-ggml_tensor * llm_build_kimi_linear::build_kda_autoregressive(
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -876,10 +867,6 @@ ggml_tensor * llm_build_kimi_linear::build_kda_autoregressive(
 
     cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);
 
-    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
-    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
-    ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {core_attn_out, state};
 }
diff --git a/src/models/models.h b/src/models/models.h
index 549329e15a..8e8f502e78 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -287,7 +287,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
 struct llm_build_kimi_linear : public llm_graph_context_mamba {
     llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
 private:
-    ggml_tensor * build_kda_autoregressive(
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -296,7 +296,7 @@ private:
         ggml_tensor * state,
         int il);
 
-    ggml_tensor * build_kda_chunking(
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_chunking(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
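For context, the Step 8 comment in the first hunk states the gating formula `output = RMSNorm(x) * sigmoid(g)`. Below is a minimal standalone ggml sketch of just that formula; the shapes, constants, and `main()` scaffolding are illustrative assumptions, not model code (the patch itself routes the norm through `build_norm`), and it assumes a recent ggml where `ggml_graph_compute_with_ctx` is declared in `ggml-cpu.h`.

```cpp
#include "ggml.h"
#include "ggml-cpu.h" // assumption: recent ggml declares ggml_graph_compute_with_ctx here
#include <cstdio>

int main() {
    ggml_init_params params = { /*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // x: input activations, g: gate logits (sizes are illustrative)
    ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_tensor * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(x, 2.0f);
    ggml_set_f32(g, 0.0f); // sigmoid(0) = 0.5

    // output = RMSNorm(x) * sigmoid(g)
    ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, x, 1e-6f), ggml_sigmoid(ctx, g));

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    printf("out[0] = %f\n", ggml_get_f32_1d(out, 0)); // ~0.5: RMSNorm of a constant vector is ~1
    ggml_free(ctx);
    return 0;
}
```

Since RMSNorm of a constant vector is ~1 and sigmoid(0) = 0.5, the expected printout is ~0.5, which makes the sigmoid (rather than SiLU) gating easy to verify.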
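Similarly, a minimal sketch of the "truncate padded tokens" view from the `build_kda_chunking` hunk: a tensor padded along dimension 1 to a whole number of chunks is trimmed back to `n_tokens` rows by reusing the padded tensor's strides, then compacted with `ggml_cont`. All dimensions here are illustrative assumptions, not values from the model.

```cpp
#include "ggml.h"

int main() {
    // Illustrative dimensions: 10 real tokens padded to 2 chunks of 8
    const int64_t S_v = 4, chunk_size = 8, n_chunks = 2, H_v = 2, n_seqs = 1;
    const int64_t n_tokens = 10;

    ggml_init_params params = { /*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // Layout after ggml_cont_4d in the patch: (S_v, chunk_size * n_chunks, H_v, n_seqs)
    ggml_tensor * padded = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, chunk_size * n_chunks, H_v, n_seqs);

    // Keep only the first n_tokens rows of dim 1. The view reuses the *padded*
    // strides (nb1..nb3), so the padding rows are skipped rather than copied.
    ggml_tensor * trimmed = ggml_view_4d(ctx, padded,
        S_v, n_tokens, H_v, n_seqs,
        ggml_row_size(padded->type, S_v),
        ggml_row_size(padded->type, S_v * chunk_size * n_chunks),
        ggml_row_size(padded->type, S_v * chunk_size * n_chunks * H_v), 0);
    trimmed = ggml_cont(ctx, trimmed); // materialize as a compact tensor

    GGML_ASSERT(trimmed->ne[1] == n_tokens);
    ggml_free(ctx);
    return 0;
}
```

Because the view only rewrites `ne`/`nb` metadata, no data moves until `ggml_cont` materializes the trimmed tensor; this is what lets the patch drop the old flatten-and-concat round trip in favor of returning the output and state tensors as a pair.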