From c8953657c474b2ad52bafe77ebac6ca9de606df9 Mon Sep 17 00:00:00 2001
From: Aes Sedai <7980540+AesSedai@users.noreply.github.com>
Date: Sun, 8 Feb 2026 12:18:49 -0800
Subject: [PATCH] Kimi-K2.5: remove v/o permutes, unnecessary

---
 convert_hf_to_gguf.py         | 17 +----------------
 tools/mtmd/models/kimik25.cpp |  7 ++-----
 2 files changed, 3 insertions(+), 21 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index c58dd91d9d..5a3f74812e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -11122,14 +11122,6 @@ class KimiK25Model(MmprojModel):
         w = w.permute(0, 2, 1, 3, 4)
         return w.reshape(out_dim, in_dim)
 
-    @staticmethod
-    def _permute_output_proj(weights: Tensor, n_head: int) -> Tensor:
-        out_dim, in_dim = weights.shape
-        head_dim = in_dim // n_head
-        w = weights.reshape(out_dim, n_head, head_dim // 4, 2, 2)
-        w = w.permute(0, 1, 3, 2, 4)
-        return w.reshape(out_dim, in_dim)
-
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # Only process vision and projector tensors
         is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
@@ -11140,10 +11132,8 @@ class KimiK25Model(MmprojModel):
         assert self.hparams_vision is not None
         n_head = self.hparams_vision.get("num_attention_heads", 16)
 
-        # Permute Q/K/V weights/biases from interleaved to split RoPE format
+        # Permute Q/K weights/biases from interleaved to split RoPE format
         # This allows using build_rope_2d at runtime without post-permutation.
-        # V is also permuted so the attention output is in split format,
-        # which is then handled by the permuted output projection.
         if "wqkv" in name:
             out_dim = data_torch.shape[0]
             qkv_dim = out_dim // 3
@@ -11153,18 +11143,13 @@ class KimiK25Model(MmprojModel):
                 wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2*qkv_dim, :], data_torch[2*qkv_dim:, :]
                 wq = self._permute_kqv(wq, n_head)
                 wk = self._permute_kqv(wk, n_head)
-                wv = self._permute_kqv(wv, n_head)
                 data_torch = torch.cat([wq, wk, wv], dim=0)
             elif "bias" in name:
                 bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2*qkv_dim], data_torch[2*qkv_dim:]
                 bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
                 bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
-                bv = bv.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
                 data_torch = torch.cat([bq, bk, bv], dim=0)
 
-        # Permute output projection from interleaved to split RoPE format
-        if "wo.weight" in name:
-            data_torch = self._permute_output_proj(data_torch, n_head)
 
         # Temporal embeddings: (T, 1, C) → (T, C)
         if "pos_emb.time_weight" in name:
diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp
index 5f5cd9b7ed..cf9f27f63a 100644
--- a/tools/mtmd/models/kimik25.cpp
+++ b/tools/mtmd/models/kimik25.cpp
@@ -42,11 +42,8 @@ ggml_cgraph * clip_graph_kimik25::build() {
 
     ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC);
 
-    // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but all attention weights
-    // (Q, K, V, O) are permuted during conversion to use split format throughout.
-    // This allows using build_rope_2d without any runtime format conversion.
-    // The dot product in attention is order-independent, so keeping everything in
-    // split format produces mathematically equivalent results.
+    // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but
+    // Q / K are permuted during conversion to use split format.
     auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
         cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
         return cur;
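
For reference, the interleaved-to-split shuffle that remains for Q/K reorders each head's channels from alternating x/y rotary pair groups to all x pairs followed by all y pairs, which is the layout build_rope_2d expects. A minimal sketch of the index mapping, mirroring the bias branch of modify_tensors above (the one-head, head_dim = 8 size is an assumption for illustration, not the model's real head size):

import torch

# Toy example: one head with head_dim = 8, i.e. two (x, y) rotary pair groups.
# Values are channel indices, so the output shows where each channel moves.
n_head, head_dim = 1, 8
b = torch.arange(n_head * head_dim)

# Same permutation as the bias branch: decompose the head into
# (pair group, x/y selector, pair) and swap the first two axes.
split = b.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)

print(b.tolist())      # [0, 1, 2, 3, 4, 5, 6, 7]  interleaved: x0 y0 x1 y1
print(split.tolist())  # [0, 1, 4, 5, 2, 3, 6, 7]  split: x0 x1 | y0 y1

Each x0/y0 label above denotes a 2-channel rotary pair, so the split layout rotates the first half of the head by the x position and the second half by the y position.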
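The removed V/O permutes were a matched pair rather than a functional requirement: permuting the rows of wv shuffles each value head's channels, and the removed _permute_output_proj applied the same shuffle to the columns of wo, so the two cancelled exactly and the attention output was unchanged either way. Since attention weights act on the token axis, they commute with any per-channel shuffle of V. A quick numerical check of that cancellation (tiny random shapes chosen for illustration; permute_rows is a stand-in for the converter's _permute_kqv):

import torch

torch.manual_seed(0)

# Illustration-only sizes, not the real model dimensions.
n_head, head_dim, n_tok = 2, 8, 5
dim = n_head * head_dim

def permute_rows(w: torch.Tensor, n_head: int) -> torch.Tensor:
    # Interleaved -> split shuffle of the output channels (rows),
    # matching what _permute_kqv does to a weight matrix.
    out_dim, in_dim = w.shape
    hd = out_dim // n_head
    w = w.reshape(n_head, hd // 4, 2, 2, in_dim).permute(0, 2, 1, 3, 4)
    return w.reshape(out_dim, in_dim)

wv = torch.randn(dim, dim)
wo = torch.randn(dim, dim)
x = torch.randn(n_tok, dim)
attn = torch.softmax(torch.randn(n_tok, n_tok), dim=-1)  # stand-in attention weights

# What the patch does now: no V/O permutes at all.
out_plain = (attn @ (x @ wv.T)) @ wo.T

# The old scheme: shuffle V's output channels, then shuffle wo's *input*
# channels (columns) identically so the projection reads them back in order.
wv_p = permute_rows(wv, n_head)
wo_p = permute_rows(wo.T, n_head).T
out_permuted = (attn @ (x @ wv_p.T)) @ wo_p.T

print(torch.allclose(out_plain, out_permuted, atol=1e-4))  # True

Q and K still need the permute because RoPE is applied to them and build_rope_2d assumes the split layout; applying the same shuffle to both leaves the Q·K dot products unchanged, which is why the scores also come out identical.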