From 6ae66fc40dcbd9562ef71ebe5cd3a7bc9686e385 Mon Sep 17 00:00:00 2001 From: Yee Man Chan Date: Sun, 11 Jan 2026 21:31:35 +0800 Subject: [PATCH] fix trailing spaces --- gguf-py/gguf/tensor_mapping.py | 4 +- src/llama-vocab.cpp | 4 +- src/models/kimi-linear.cpp | 82 ++++++++++++++++------------------ 3 files changed, 42 insertions(+), 48 deletions(-) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 88e2caf541..c4957a7b20 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -819,13 +819,13 @@ class TensorNameMap: # Kimi Linear KDA (using SSM_ prefix for consistency) MODEL_TENSOR.SSM_CONV1D_Q: ( "model.layers.{bid}.self_attn.q_conv1d", - ), + ), MODEL_TENSOR.SSM_CONV1D_K: ( "model.layers.{bid}.self_attn.k_conv1d", ), MODEL_TENSOR.SSM_CONV1D_V: ( "model.layers.{bid}.self_attn.v_conv1d", - ), + ), MODEL_TENSOR.SSM_F_A: ( "model.layers.{bid}.self_attn.f_a_proj", ), diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index eaa574f3b8..f7a264dc60 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1747,7 +1747,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); // Kimi-K2 uses custom tokenization without traditional BPE merges const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); - + if (merges_keyidx == -1) { if (!is_kimi_k2) { throw std::runtime_error("cannot find tokenizer merges in model file\n"); @@ -1768,7 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (pos != std::string::npos) { first = word.substr(0, pos); second = word.substr(pos + 1); - } + } bpe_ranks.emplace(std::make_pair(first, second), i); } diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 9d83ca8fa5..e873024c90 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -12,7 +12,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) // So we don't need inp_pos - + auto * inp = build_inp_mem_hybrid(); auto * inp_rs = inp->get_recr(); auto * inp_attn = inp->get_attn(); @@ -38,12 +38,12 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096 const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - + // Verify batch consistency for recurrent layers GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - + // MLA params const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla; const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla; @@ -67,14 +67,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor bool is_kda = (layer.ssm_a_log != nullptr); bool is_mla = (layer.wkv_a_mqa != nullptr); - + if (is_kda) { // === KDA Layer (Kimi Delta Attention) with Recurrent State === // Reference: vLLM kda.py - const auto * mctx_cur = inp_rs->mctx; const auto kv_head = mctx_cur->get_head(); - + // Get conv states from r_l tensor (Q, K, V each have separate state) ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); cb(conv_states_all, "conv_states_all", il); @@ -85,7 +84,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs] // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size - // View Q conv state: offset 0, size conv_state_size per seq // conv_state_all is [n_embd_r_total, n_seqs] with memory layout: // state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V @@ -104,7 +102,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll (d_conv - 1) * ggml_element_size(conv_state_all), n_embd_r_total * ggml_element_size(conv_state_all), 2 * conv_state_size * ggml_element_size(conv_state_all)); // offset for V - + // Step 1: Q, K, V projections -> [d_inner, n_tokens] ggml_tensor * q_proj = ggml_mul_mat(ctx0, layer.wq, cur); ggml_tensor * k_proj = ggml_mul_mat(ctx0, layer.wk, cur); @@ -112,14 +110,14 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cb(q_proj, "kda_q_proj", il); cb(k_proj, "kda_k_proj", il); cb(v_proj, "kda_v_proj", il); - + // Step 2: Causal Conv1d for Q // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} ggml_tensor * q_3d = ggml_reshape_3d(ctx0, q_proj, d_inner, n_seq_tokens, n_seqs); - + // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs} ggml_tensor * conv_q = ggml_concat(ctx0, conv_state_q, ggml_transpose(ctx0, q_3d), 0); - + // Save last (d_conv-1) columns back to Q conv state ggml_tensor * last_conv_q = ggml_view_3d(ctx0, conv_q, d_conv - 1, d_inner, n_seqs, conv_q->nb[1], conv_q->nb[2], n_seq_tokens * conv_q->nb[0]); @@ -127,7 +125,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_cpy(ctx0, last_conv_q, ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, kv_head * n_embd_r_total * ggml_element_size(conv_states_all)))); - // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] @@ -143,13 +140,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll } conv_weight = ggml_reshape_2d(ctx0, q_conv_f32, d_conv, d_inner); } - + // Apply conv1d ggml_tensor * Qcur; if (conv_weight) { // Make conv_q contiguous for ggml_ssm_conv conv_q = ggml_cont(ctx0, conv_q); - + // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs} Qcur = ggml_ssm_conv(ctx0, conv_q, conv_weight); cb(Qcur, "Q conv1d", il); @@ -163,13 +160,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll } else { GGML_ABORT("KDA layer missing Q conv weight"); } - + // K conv1d (with separate K conv state) ggml_tensor * Kcur; if (layer.ssm_k_conv) { ggml_tensor * k_3d = ggml_reshape_3d(ctx0, k_proj, d_inner, n_seq_tokens, n_seqs); ggml_tensor * conv_k = ggml_cont(ctx0, ggml_concat(ctx0, conv_state_k, ggml_transpose(ctx0, k_3d), 0)); - + // Save K conv state ggml_tensor * last_conv_k = ggml_view_3d(ctx0, conv_k, d_conv - 1, d_inner, n_seqs, conv_k->nb[1], conv_k->nb[2], n_seq_tokens * conv_k->nb[0]); @@ -177,7 +174,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_cpy(ctx0, last_conv_k, ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, (kv_head * n_embd_r_total + conv_state_size) * ggml_element_size(conv_states_all)))); - + ggml_tensor * k_conv_f32 = layer.ssm_k_conv; if (k_conv_f32->type != GGML_TYPE_F32) { k_conv_f32 = ggml_cast(ctx0, k_conv_f32, GGML_TYPE_F32); @@ -194,13 +191,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll } else { GGML_ABORT("KDA layer missing K conv weight"); } - + // V conv1d (with separate V conv state) ggml_tensor * Vcur; if (layer.ssm_v_conv) { ggml_tensor * v_3d = ggml_reshape_3d(ctx0, v_proj, d_inner, n_seq_tokens, n_seqs); ggml_tensor * conv_v = ggml_cont(ctx0, ggml_concat(ctx0, conv_state_v, ggml_transpose(ctx0, v_3d), 0)); - + // Save V conv state ggml_tensor * last_conv_v = ggml_view_3d(ctx0, conv_v, d_conv - 1, d_inner, n_seqs, conv_v->nb[1], conv_v->nb[2], n_seq_tokens * conv_v->nb[0]); @@ -208,7 +205,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_cpy(ctx0, last_conv_v, ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, (kv_head * n_embd_r_total + 2 * conv_state_size) * ggml_element_size(conv_states_all)))); - + ggml_tensor * v_conv_f32 = layer.ssm_v_conv; if (v_conv_f32->type != GGML_TYPE_F32) { v_conv_f32 = ggml_cast(ctx0, v_conv_f32, GGML_TYPE_F32); @@ -225,7 +222,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll } else { GGML_ABORT("KDA layer missing V conv weight"); } - + // Step 3: Compute g1 (forget gate) // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias) ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur); @@ -234,7 +231,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll g1 = ggml_add(ctx0, g1, layer.ssm_dt_b); g1 = ggml_softplus(ctx0, g1); g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens); - + // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens] // First compute -exp(A_log), then reshape for broadcasting ggml_tensor * A_neg_exp = ggml_neg(ctx0, ggml_exp(ctx0, layer.ssm_a_log)); @@ -242,16 +239,16 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll A_neg_exp = ggml_reshape_3d(ctx0, A_neg_exp, 1, n_head, 1); g1 = ggml_mul(ctx0, g1, A_neg_exp); cb(g1, "kda_g1", il); - + // Step 4: Compute beta (mixing coefficient) ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur); beta = ggml_cont_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs); cb(beta, "kda_beta", il); - + // Step 5: Reshape for KDA recurrence // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - + Qcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Qcur, head_dim, n_head, n_seq_tokens, n_seqs)); Kcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Kcur, head_dim, n_head, n_seq_tokens, n_seqs)); Vcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Vcur, head_dim, n_head, n_seq_tokens, n_seqs)); @@ -274,7 +271,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll const int64_t output_flat_size = head_dim * n_head * n_seq_tokens * n_seqs; ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); cb(attn_out_1d, "attn_out_1d", il); - + ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, attn_out_1d, head_dim, n_head, n_seq_tokens * n_seqs); cb(attn_out_final, "attn_out_reshaped", il); // Extract the state part (second part of the concatenated tensor) @@ -299,7 +296,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a); cb(g2, "g2 g_b(g_a(cur_2d))", il); g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs); - + // Step 8: Apply o_norm with sigmoid gating // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish) // Formula: output = RMSNorm(x) * sigmoid(g) @@ -307,7 +304,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cb(normed, "kda_normed", il); ggml_tensor * gate = ggml_sigmoid(ctx0, g2); ggml_tensor * gated = ggml_mul(ctx0, normed, gate); - + // Step 9: Output projection gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens); cur = ggml_mul_mat(ctx0, layer.wo, gated); @@ -316,7 +313,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll } else if (is_mla) { // === MLA Layer (Multi-head Latent Attention) without KV Cache === // Reference: vLLM mla.py - // Step 1: Q projection and reshape // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim] // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) @@ -325,7 +321,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Step 2: KV compression // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens] ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur); - + // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:] ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); @@ -333,10 +329,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); - // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM) // k_pe is used directly without RoPE - // Normalize kv_c kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); @@ -346,7 +340,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0); cb(q_nope, "q_nope", il); - + // and {n_embd_head_qk_rope, n_head, n_tokens} ggml_tensor * q_pe = ggml_view_3d( ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), @@ -389,7 +383,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // KV decompression: kv = kv_b_proj(kv_c_normed) ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr); const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla; - + // Split kv into k_nope and v ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(kv->type, kv_per_head), @@ -401,7 +395,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll k_nope = ggml_cont(ctx0, k_nope); Vcur = ggml_cont(ctx0, Vcur); cb(Vcur, "mla_V", il); - + // Concatenate k_nope + k_pe (broadcast k_pe to all heads) // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens] // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads @@ -410,7 +404,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target); ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0); cb(Kcur, "mla_K", il); - + // Direct softmax attention (with MHA KV cache) // Use build_attn with inp_attn for proper mask handling cur = build_attn(inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); @@ -420,13 +414,13 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Unknown layer type - this should not happen GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors"); } - + // On last layer, select only the output tokens if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - + // Residual ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -459,7 +453,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); - + // Shared expert { ggml_tensor * ffn_shexp = build_ffn(cur, @@ -468,7 +462,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll layer.ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); - + cur = ggml_add(ctx0, moe_out, ffn_shexp); cb(cur, "ffn_out", il); } @@ -663,7 +657,7 @@ ggml_tensor * llm_build_kimi_linear::build_kda_chunking( Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB))); cb(Akk, "Akk", il); cb(Aqk, "Aqk", il); - + Akk = ggml_mul(ctx0, Akk, beta); Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask)); cb(Akk, "attn_pre_solve", il); @@ -798,15 +792,15 @@ ggml_tensor * llm_build_kimi_linear::build_kda_autoregressive( ggml_tensor * v, ggml_tensor * gk, ggml_tensor * beta, - ggml_tensor * state, + ggml_tensor * state, int il) { GGML_ASSERT(ggml_is_contiguous(q)); GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(v)); GGML_ASSERT(ggml_is_contiguous(gk)); GGML_ASSERT(ggml_is_contiguous(beta)); GGML_ASSERT(ggml_is_contiguous(state)); - + const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; const int64_t n_tokens = q->ne[2];