From 042c3cb8c5cdd857e55300630f405bb5041afd4b Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Wed, 28 Jan 2026 22:06:59 -0800 Subject: [PATCH 01/12] Move dequant_model to after the text_config merge Add new kimi-k2.5 keys to mtmd convert Update V_MMPROJ tensor mapping for new mm_projector.proj keys Update V_M_IMP_NORM for new mm_projector.pre_norm key --- convert_hf_to_gguf.py | 41 ++++++++++++++++++++++++---------- gguf-py/gguf/tensor_mapping.py | 2 ++ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index eb43520f98..8e293cd970 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -160,8 +160,6 @@ class ModelBase: self.ftype = gguf.LlamaFileType.MOSTLY_F16 logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16") - self.dequant_model() - # Configure GGUF Writer self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @@ -527,6 +525,8 @@ class ModelBase: return () def prepare_tensors(self): + self.dequant_model() + # Handle empty tensor_map for models with block_count=0 (like MobileNetV5) if self.tensor_map.mapping: max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") @@ -1808,7 +1808,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1863,7 +1863,15 @@ class MmprojModel(ModelBase): preprocessor_config_path = self.dir_model / "preprocessor_config.json" if preprocessor_config_path.is_file(): with open(preprocessor_config_path, "r", encoding="utf-8") as f: - self.preprocessor_config = json.load(f) + cfg = json.load(f) + # move media_proc_cfg to root level for compat + if "media_proc_cfg" in cfg: + cfg = { + **cfg, + **cfg["media_proc_cfg"], + } + # merge configs + self.preprocessor_config = {**self.preprocessor_config, **cfg} # prefer processor_config.json if possible processor_config_path = self.dir_model / "processor_config.json" @@ -1912,10 +1920,10 @@ class MmprojModel(ModelBase): self.image_size = self.find_vparam(["image_size"]) self.gguf_writer.add_vision_image_size(self.image_size) self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) - self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) - self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"])) self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) - self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"])) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"])) # preprocessor config image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else 
self.preprocessor_config["image_mean"] @@ -7360,6 +7368,7 @@ class DeepseekModel(TextModel): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "KimiK25ForConditionalGeneration", "YoutuForCausalLM", "YoutuVLForConditionalGeneration", ) @@ -7478,8 +7487,8 @@ class DeepseekV2Model(TextModel): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # skip vision tensors and remove "language_model." for Kimi-VL - if "vision_tower" in name or "multi_modal_projector" in name: + # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5 + if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name: return if name.startswith("siglip2.") or name.startswith("merger."): return @@ -10614,7 +10623,7 @@ class MistralMoeModel(DeepseekV2Model): self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1 def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name: + if name.startswith("vision_") or name.startswith("patch_merger."): return # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic @@ -10679,7 +10688,7 @@ class LightOnOCRVisionModel(LlavaVisionModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("KimiVLForConditionalGeneration") +@ModelBase.register("KimiVLForConditionalGeneration", "KimiK25ForConditionalGeneration") class KimiVLModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -10696,9 +10705,17 @@ class KimiVLModel(MmprojModel): self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5)) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name + is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name if is_vision_tensor: + # update names: + # "mm_projector.proj.0" -> "mm_projector.proj.linear_1.", + # "mm_projector.proj.2" -> "mm_projector.proj.linear_2.", + if "proj.0." in name: + name = name.replace(".0.", ".linear_1.") + if "proj.2." 
in name: + name = name.replace(".2.", ".linear_2.") + if "pos_emb.weight" in name: data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2]) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 84aa868809..456b3640c9 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1255,6 +1255,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", + "mm_projector.proj.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl "merger.mlp.{bid}", ), @@ -1490,6 +1491,7 @@ class TensorNameMap: "multi_modal_projector.norm", "multi_modal_projector.layer_norm", "multi_modal_projector.pre_norm", + "mm_projector.pre_norm", # Kimi-K2.5 "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm "merger.ln_q", From a4c9a08270abf5344ef1d943c5c32a71d485a4f0 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Wed, 28 Jan 2026 22:24:26 -0800 Subject: [PATCH 02/12] Fix a couple of oversights --- convert_hf_to_gguf.py | 2 +- gguf-py/gguf/tensor_mapping.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8e293cd970..adfb9ebce5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10623,7 +10623,7 @@ class MistralMoeModel(DeepseekV2Model): self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1 def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - if name.startswith("vision_") or name.startswith("patch_merger."): + if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name: return # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 456b3640c9..35350c3fe1 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1255,7 +1255,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", - "mm_projector.proj.linear_{bid}", + "mm_projector.proj.linear_{bid}", # Kimi-K2.5 "visual.merger.mlp.{bid}", # qwen2vl "merger.mlp.{bid}", ), From 9c44981c0113c9ba3fe04b9b8ba2c034ea3a6021 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 1 Feb 2026 02:14:26 -0800 Subject: [PATCH 03/12] Add image support for Kimi-K2.5 --- convert_hf_to_gguf.py | 71 +++++++++++++++++++++- gguf-py/gguf/constants.py | 1 + tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-graph.h | 11 ++++ tools/mtmd/clip-impl.h | 2 + tools/mtmd/clip.cpp | 109 ++++++++++++++++++++++++++++++++++ tools/mtmd/models/kimik25.cpp | 98 ++++++++++++++++++++++++++++++ tools/mtmd/models/models.h | 7 +++ 8 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 tools/mtmd/models/kimik25.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index adfb9ebce5..a1d1a05fcf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10688,7 +10688,7 @@ class LightOnOCRVisionModel(LlavaVisionModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("KimiVLForConditionalGeneration", "KimiK25ForConditionalGeneration") +@ModelBase.register("KimiVLForConditionalGeneration") class KimiVLModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -10729,6 +10729,75 @@ class KimiVLModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) 
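+# Kimi-K2.5 note: the text-model weights are converted through the existing DeepSeek path
+# (KimiK25ForConditionalGeneration is registered with it earlier in this file), so the class
+# below only covers the MoonViT3d vision tower and mm_projector tensors for the mmproj file.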
+@ModelBase.register("KimiK25ForConditionalGeneration") +class KimiK25Model(MmprojModel): + """Kimi-K2.5 with MoonViT3d vision encoder""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config" + + self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2])) + self.patch_size = self.hparams_vision.get("patch_size", 14) + + # Set image_size for compatibility with base class + # Use position embedding dimensions as image_size reference + pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64) + self.hparams_vision["image_size"] = pos_emb_h * self.patch_size + + def set_gguf_parameters(self): + # Base class MmprojModel.set_gguf_parameters() already writes: + # - vision_block_count, vision_head_count, vision_embedding_length + # - vision_feed_forward_length, vision_patch_size, image_mean, image_std + # via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config + super().set_gguf_parameters() + assert self.hparams_vision is not None + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25) + + # Position embedding parameters (for interpolation) - KimiK25-specific + self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64)) + self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64)) + self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4)) + + # Projector parameters + self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu") + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5)) + self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Only process vision and projector tensors + is_vision = any(x in name for x in ["vision_tower", "mm_projector"]) + + if not is_vision: + return + + # Split fused QKV tensors in vision encoder + if "wqkv" in name: + split_dim = 0 if "weight" in name else -1 + wq, wk, wv = data_torch.chunk(3, dim=split_dim) + yield from super().modify_tensors(wq, name.replace("wqkv", "wq"), bid) + yield from super().modify_tensors(wk, name.replace("wqkv", "wk"), bid) + yield from super().modify_tensors(wv, name.replace("wqkv", "wv"), bid) + return + + # Temporal embeddings: (T, 1, C) → (T, C) + if "pos_emb.time_weight" in name: + T, _, C = data_torch.shape + data_torch = data_torch.reshape(T, C) + + # PatchMergerMLP tensor name mapping + # proj.0.weight → proj.linear_1.weight + # proj.2.weight → proj.linear_2.weight + if "mm_projector.proj.0." in name: + name = name.replace(".proj.0.", ".proj.linear_1.") + elif "mm_projector.proj.2." 
in name: + name = name.replace(".proj.2.", ".proj.linear_2.") + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("CogVLMForCausalLM") class CogVLMVisionModel(MmprojModel): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 31273b2b5a..229d9db5e2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3610,6 +3610,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + KIMIK25 = "kimik25" LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af32..02d71f224e 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(mtmd models/glm4v.cpp models/internvl.cpp models/kimivl.cpp + models/kimik25.cpp models/llama4.cpp models/llava.cpp models/minicpmv.cpp diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 4c7f7504cf..8c9d56c8cb 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -107,6 +107,17 @@ struct clip_graph { const bool interleave_freq ); + // 2D RoPE with interleaved frequency + // Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...] + // build_rope_2d uses split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...] + ggml_tensor * build_rope_2d_interleaved( + ggml_context * ctx0, + ggml_tensor * cur, // [n_dim, n_head, n_pos] + ggml_tensor * pos_w, // [n_pos] - X/width positions + ggml_tensor * pos_h, // [n_pos] - Y/height positions + const float freq_base + ); + // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) // support dynamic resolution ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index dd693623a2..7b012f6d5e 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -233,6 +233,7 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_UNKNOWN, }; @@ -266,6 +267,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_KIMIK25, "kimik25"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9fa5afc390..daa7a01379 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -710,6 +710,83 @@ ggml_tensor * clip_graph::build_rope_2d( return cur; } +// 2D RoPE with interleaved frequency +// Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...] +// build_rope_2d uses split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...] 
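+// Each group of four channels holds (x_re, x_im, y_re, y_im): the x pair is rotated with the
+// width positions and the y pair with the height positions, then the two halves are
+// concatenated back into the original group-of-four layout.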
+ggml_tensor * clip_graph::build_rope_2d_interleaved( + ggml_context * ctx0, + ggml_tensor * cur, // [n_dim, n_head, n_pos] + ggml_tensor * pos_w, // [n_pos] - X/width positions + ggml_tensor * pos_h, // [n_pos] - Y/height positions + const float freq_base +) { + const int64_t n_dim = cur->ne[0]; + const int64_t n_head = cur->ne[1]; + const int64_t n_pos = cur->ne[2]; + + GGML_ASSERT(n_dim % 4 == 0); // Must be divisible by 4 for interleaved x,y pairs + + // Step 1: Reshape to expose interleaved structure + // cur: [n_dim, n_head, n_pos] -> [4, n_dim/4, n_head, n_pos] + ggml_tensor * reshaped = ggml_reshape_4d(ctx0, cur, 4, n_dim/4, n_head, n_pos); + + // Step 2: Extract X pairs (elements 0,1 of each group of 4) + // x_pairs: [2, n_dim/4, n_head, n_pos] + ggml_tensor * x_pairs = ggml_view_4d(ctx0, reshaped, + 2, n_dim/4, n_head, n_pos, + reshaped->nb[1], reshaped->nb[2], reshaped->nb[3], + 0); + + // Step 3: Extract Y pairs (elements 2,3 of each group of 4) + // y_pairs: [2, n_dim/4, n_head, n_pos] + ggml_tensor * y_pairs = ggml_view_4d(ctx0, reshaped, + 2, n_dim/4, n_head, n_pos, + reshaped->nb[1], reshaped->nb[2], reshaped->nb[3], + 2 * ggml_element_size(reshaped)); + + // Step 4: Make contiguous and reshape for rope_ext + // [2, n_dim/4, n_head, n_pos] -> [n_dim/2, n_head, n_pos] + x_pairs = ggml_cont(ctx0, x_pairs); + x_pairs = ggml_reshape_3d(ctx0, x_pairs, n_dim/2, n_head, n_pos); + + y_pairs = ggml_cont(ctx0, y_pairs); + y_pairs = ggml_reshape_3d(ctx0, y_pairs, n_dim/2, n_head, n_pos); + + // Step 5: Apply RoPE to X pairs using pos_w, Y pairs using pos_h + x_pairs = ggml_rope_ext( + ctx0, + x_pairs, + pos_w, + nullptr, + n_dim/2, + 0, 0, freq_base, + 1.0f, 0.0f, 1.0f, 0.0f, 0.0f + ); + + y_pairs = ggml_rope_ext( + ctx0, + y_pairs, + pos_h, + nullptr, + n_dim/2, + 0, 0, freq_base, + 1.0f, 0.0f, 1.0f, 0.0f, 0.0f + ); + + // Step 6: Reshape back to [2, n_dim/4, n_head, n_pos] for interleaving + x_pairs = ggml_reshape_4d(ctx0, x_pairs, 2, n_dim/4, n_head, n_pos); + y_pairs = ggml_reshape_4d(ctx0, y_pairs, 2, n_dim/4, n_head, n_pos); + + // Step 7: Interleave X and Y pairs back together + // Concatenate along dimension 0: [4, n_dim/4, n_head, n_pos] + ggml_tensor * result = ggml_concat(ctx0, x_pairs, y_pairs, 0); + + // Step 8: Reshape back to original: [n_dim, n_head, n_pos] + result = ggml_reshape_3d(ctx0, result, n_dim, n_head, n_pos); + + return result; +} + // Generic function to stack frames for audio processing // Abstracts out the StackAudioFrames logic used by ultravox ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) { @@ -825,6 +902,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_KIMIK25: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_COGVLM: { builder = std::make_unique(ctx, img); @@ -1139,6 +1220,13 @@ struct clip_model_loader { hparams.set_limit_image_tokens(8, 1024); hparams.set_warmup_n_tokens(256); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_KIMIK25: + { + hparams.rope_theta = 10000.0f; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + hparams.set_limit_image_tokens(8, 4096); + hparams.set_warmup_n_tokens(256); + } break; case PROJECTOR_TYPE_GEMMA3: { // default value (used by all model sizes in gemma 3 family) @@ -1668,6 +1756,7 @@ struct clip_model_loader { model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; case PROJECTOR_TYPE_KIMIVL: + case 
PROJECTOR_TYPE_KIMIK25: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); @@ -3039,6 +3128,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(res)); } break; + case PROJECTOR_TYPE_KIMIK25: + { + GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); + const clip_image_size target_size = img_tool::calc_size_preserved_ratio( + original_size, + params.patch_size * params.n_merge, + params.image_min_pixels, + params.image_max_pixels); + const std::array pad_color = {0, 0, 0}; + + clip_image_u8 resized_img; + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } break; + case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: case PROJECTOR_TYPE_LDP: @@ -3247,6 +3353,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: { // dynamic size int out_patch_size = params.patch_size * ctx->model.hparams.n_merge; @@ -3588,6 +3695,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: case PROJECTOR_TYPE_LIGHTONOCR: { // set the 2D positions @@ -3770,6 +3878,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp new file mode 100644 index 0000000000..a34240b8bb --- /dev/null +++ b/tools/mtmd/models/kimik25.cpp @@ -0,0 +1,98 @@ +#include "models.h" +#include +#include + +// note: this is similar to clip_graph::resize_position_embeddings, major difference is having +// the w/h in ne[1] and ne[2] instead of assuming with sqrt. Could try storing the tensor in 2D instead +// with a w*h? Also the permute is a bit different at (2, 1, 0, 3) instead of (2, 0, 1, 3). 
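+// The stored embedding is a [C, W, H] grid (64x64 patch positions by default); it is
+// interpolated to the current image's patch grid (img.nx/patch_size by img.ny/patch_size)
+// and flattened to [C, W*H] so it can be added directly to the patch embeddings.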
+ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) { + ggml_tensor * pos_embd = model.position_embeddings; + const int height = img.ny / patch_size; + const int width = img.nx / patch_size; + const uint32_t mode = interpolation_mode; + + GGML_ASSERT(pos_embd); + + const int64_t stored_c = pos_embd->ne[0]; // C = 1152 + const int64_t orig_w = pos_embd->ne[1]; // W = 64 + const int64_t orig_h = pos_embd->ne[2]; // H = 64 + + GGML_ASSERT(stored_c == n_embd); + + if (height == (int)orig_h && width == (int)orig_w) { + // No interpolation needed, just flatten to [C, H*W] + return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); + } + + pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); + pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode); + pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); + pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); + return pos_embd; +} + +ggml_cgraph * clip_graph_kimik25::build() { + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC); + + // Kimi-K2.5 uses INTERLEAVED frequency pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...] + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return build_rope_2d_interleaved(ctx0, cur, pos_w, pos_h, hparams.rope_theta); + }; + + ggml_tensor * inp = build_inp(); + + // I don't know why, but doing this in the build_vit lead to the ggml_add not occurring? + // Doing it manually here does work. 
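+ // add the interpolated learned position embedding to the patch embeddings before the ViT blocks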
+ inp = ggml_add(ctx0, inp, learned_pos_embd); + + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + nullptr, + add_pos); + + cb(cur, "vit_out", -1); + + { + // patch_merger + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection norm + int proj_inp_dim = cur->ne[0]; + cur = ggml_view_2d(ctx0, cur, + n_embd, cur->ne[1] * scale_factor * scale_factor, + ggml_row_size(cur->type, n_embd), 0); + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + cur = ggml_view_2d(ctx0, cur, + proj_inp_dim, cur->ne[1], + ggml_row_size(cur->type, proj_inp_dim), 0); + cb(cur, "proj_inp_normed", -1); + + // projection mlp + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + + cb(cur, "proj_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7b..c4c67ace62 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -109,3 +109,10 @@ struct clip_graph_mobilenetv5 : clip_graph { ggml_tensor * inp, const mobilenetv5_block & block); }; + +struct clip_graph_kimik25 : clip_graph { + clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode); +}; From 9b14cb8b2812e2c6c846a825f4343a8e7bbe4e34 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 1 Feb 2026 02:19:53 -0800 Subject: [PATCH 04/12] Revert changes to KimiVLForConditionalGeneration --- convert_hf_to_gguf.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a1d1a05fcf..00cb23c971 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10705,17 +10705,9 @@ class KimiVLModel(MmprojModel): self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5)) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name + is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name if is_vision_tensor: - # update names: - # "mm_projector.proj.0" -> "mm_projector.proj.linear_1.", - # "mm_projector.proj.2" -> "mm_projector.proj.linear_2.", - if "proj.0." in name: - name = name.replace(".0.", ".linear_1.") - if "proj.2." 
in name: - name = name.replace(".2.", ".linear_2.") - if "pos_emb.weight" in name: data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2]) From 37a386dd93a5f8e505096a05f75093916512ed72 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 1 Feb 2026 02:56:15 -0800 Subject: [PATCH 05/12] Fix an assert crash --- tools/mtmd/models/kimik25.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp index a34240b8bb..d79b2f39c2 100644 --- a/tools/mtmd/models/kimik25.cpp +++ b/tools/mtmd/models/kimik25.cpp @@ -69,14 +69,15 @@ ggml_cgraph * clip_graph_kimik25::build() { // projection norm int proj_inp_dim = cur->ne[0]; + int n_merged_patches = cur->ne[1]; cur = ggml_view_2d(ctx0, cur, - n_embd, cur->ne[1] * scale_factor * scale_factor, + n_embd, n_merged_patches * scale_factor * scale_factor, ggml_row_size(cur->type, n_embd), 0); cur = ggml_norm(ctx0, cur, hparams.eps); cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); cur = ggml_add(ctx0, cur, model.mm_input_norm_b); cur = ggml_view_2d(ctx0, cur, - proj_inp_dim, cur->ne[1], + proj_inp_dim, n_merged_patches, ggml_row_size(cur->type, proj_inp_dim), 0); cb(cur, "proj_inp_normed", -1); From b1cf34ebe07edb3d1684c205676793ea81335049 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Wed, 4 Feb 2026 06:25:24 -0800 Subject: [PATCH 06/12] Fix permute swapping w / h on accident --- tools/mtmd/models/kimik25.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp index d79b2f39c2..6db47e2c97 100644 --- a/tools/mtmd/models/kimik25.cpp +++ b/tools/mtmd/models/kimik25.cpp @@ -26,7 +26,7 @@ ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpo pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode); - pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); + pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); return pos_embd; } From be1b0c35546cafdfc91296b1f95329834409fe4b Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sat, 7 Feb 2026 21:14:17 -0800 Subject: [PATCH 07/12] Kimi-K2.5: Use merged QKV for vision --- convert_hf_to_gguf.py | 9 --------- tools/mtmd/clip.cpp | 5 +++++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 826fb707ab..7d8fb6bb12 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11113,15 +11113,6 @@ class KimiK25Model(MmprojModel): if not is_vision: return - # Split fused QKV tensors in vision encoder - if "wqkv" in name: - split_dim = 0 if "weight" in name else -1 - wq, wk, wv = data_torch.chunk(3, dim=split_dim) - yield from super().modify_tensors(wq, name.replace("wqkv", "wq"), bid) - yield from super().modify_tensors(wk, name.replace("wqkv", "wk"), bid) - yield from super().modify_tensors(wv, name.replace("wqkv", "wv"), bid) - return - # Temporal embeddings: (T, 1, C) → (T, C) if "pos_emb.time_weight" in name: T, _, C = data_torch.shape diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index daa7a01379..0fd07f5ca7 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -726,6 +726,11 @@ ggml_tensor * clip_graph::build_rope_2d_interleaved( GGML_ASSERT(n_dim % 4 == 0); // Must be 
divisible by 4 for interleaved x,y pairs + // Ensure input is contiguous (needed when using merged QKV with ggml_view) + if (!ggml_is_contiguous(cur)) { + cur = ggml_cont(ctx0, cur); + } + // Step 1: Reshape to expose interleaved structure // cur: [n_dim, n_head, n_pos] -> [4, n_dim/4, n_head, n_pos] ggml_tensor * reshaped = ggml_reshape_4d(ctx0, cur, 4, n_dim/4, n_head, n_pos); From 052fda6c5d61cbfec152f323c8e1e3e770c419a2 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:18:45 -0800 Subject: [PATCH 08/12] Kimi-K2.5: pre-convert vision QK to use build_rope_2d --- convert_hf_to_gguf.py | 61 ++++++++++++++++++++++++++++++++++- tools/mtmd/clip.cpp | 18 +++++++++++ tools/mtmd/models/kimik25.cpp | 29 +++++++++++++++-- 3 files changed, 105 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7d8fb6bb12..070b22fcd9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11096,7 +11096,7 @@ class KimiK25Model(MmprojModel): self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25) - # Position embedding parameters (for interpolation) - KimiK25-specific + # Position embedding parameters (for interpolation) self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64)) self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64)) self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4)) @@ -11106,6 +11106,43 @@ class KimiK25Model(MmprojModel): self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5)) self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0]) + # Image size limits (from preprocessor_config.json media_proc_cfg) + # These are used to set token limits: tokens = pixels / (patch_size²) + in_patch_limit = self.preprocessor_config.get("in_patch_limit_each_frame", + self.preprocessor_config.get("in_patch_limit", 4096)) + min_patches = 8 # reasonable minimum + pixels_per_patch = self.patch_size * self.patch_size + self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch) + self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch) + + @staticmethod + def _permute_rope_interleaved_to_split(weights: Tensor, n_head: int) -> Tensor: + """Permute Q/K weights from interleaved to split RoPE format. + + Kimi-K2.5 uses interleaved 2D RoPE pattern (per head): + [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, y1_re, y1_im, ...] + i.e., groups of 4: (x_pair, y_pair) repeated + + llama.cpp build_rope_2d expects split format (per head): + [x0_re, x0_im, x1_re, x1_im, ..., y0_re, y0_im, y1_re, y1_im, ...] + i.e., first half is all X pairs, second half is all Y pairs + + This permutation is applied at conversion time so we can use build_rope_2d at runtime. 
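+
+ Worked example for head_dim = 8 (two rotation frequencies per axis): within each head the
+ interleaved row order [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, y1_re, y1_im] becomes
+ [x0_re, x0_im, x1_re, x1_im, y0_re, y0_im, y1_re, y1_im], i.e. rows are gathered in the
+ order [0, 1, 4, 5, 2, 3, 6, 7].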
+ """ + out_dim, in_dim = weights.shape + head_dim = out_dim // n_head + # Reshape to expose the interleaved structure: + # [n_head, head_dim//4, 2, 2, in_dim] + # where: head_dim//4 = number of (x,y) frequency pairs + # first 2 = x_or_y (0=x, 1=y) + # second 2 = re_or_im (real, imaginary parts of complex rotation) + w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim) + # Permute to split format: [n_head, 2, head_dim//4, 2, in_dim] + # Now dim 1 separates X (index 0) from Y (index 1) + w = w.permute(0, 2, 1, 3, 4) + # Reshape back: [out_dim, in_dim] + return w.reshape(out_dim, in_dim) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Only process vision and projector tensors is_vision = any(x in name for x in ["vision_tower", "mm_projector"]) @@ -11113,6 +11150,28 @@ class KimiK25Model(MmprojModel): if not is_vision: return + assert self.hparams_vision is not None + n_head = self.hparams_vision.get("num_attention_heads", 16) + + # Permute Q/K weights/biases from interleaved to split RoPE format + # This allows using the build_rope_2d at runtime + if "wqkv" in name: + out_dim = data_torch.shape[0] + qkv_dim = out_dim // 3 + head_dim = qkv_dim // n_head + + if "weight" in name: + wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2*qkv_dim, :], data_torch[2*qkv_dim:, :] + wq = self._permute_rope_interleaved_to_split(wq, n_head) + wk = self._permute_rope_interleaved_to_split(wk, n_head) + data_torch = torch.cat([wq, wk, wv], dim=0) + elif "bias" in name: + bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2*qkv_dim], data_torch[2*qkv_dim:] + # Same permutation as weights: [n_head, head_dim//4, 2, 2] -> [n_head, 2, head_dim//4, 2] + bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + data_torch = torch.cat([bq, bk, bv], dim=0) + # Temporal embeddings: (T, 1, C) → (T, C) if "pos_emb.time_weight" in name: T, _, C = data_torch.shape diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 0fd07f5ca7..eb174e4b17 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -655,6 +655,11 @@ ggml_tensor * clip_graph::build_rope_2d( const int64_t n_head = cur->ne[1]; const int64_t n_pos = cur->ne[2]; + // Ensure input is contiguous (needed when using merged QKV with ggml_view) + if (!ggml_is_contiguous(cur)) { + cur = ggml_cont(ctx0, cur); + } + // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos) // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3 // first half of cur will use 1e-0, 1e-2 (even) @@ -1229,7 +1234,20 @@ struct clip_model_loader { { hparams.rope_theta = 10000.0f; get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + + // Read min/max pixels from GGUF and convert to token limits + int min_pixels = 0, max_pixels = 0; + get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false); + get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false); + if (min_pixels > 0 && max_pixels > 0) { + const int pixels_per_patch = hparams.patch_size * hparams.patch_size; + const int min_tokens = min_pixels / pixels_per_patch; + const int max_tokens = max_pixels / pixels_per_patch; + hparams.set_limit_image_tokens(min_tokens, max_tokens); + } else { + // Fallback to hardcoded defaults hparams.set_limit_image_tokens(8, 4096); + } hparams.set_warmup_n_tokens(256); } break; case PROJECTOR_TYPE_GEMMA3: diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp index 6db47e2c97..ceb7b848f9 100644 --- 
a/tools/mtmd/models/kimik25.cpp +++ b/tools/mtmd/models/kimik25.cpp @@ -42,9 +42,34 @@ ggml_cgraph * clip_graph_kimik25::build() { ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC); - // Kimi-K2.5 uses INTERLEAVED frequency pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...] + // Kimi-K2.5 uses interleaved 2D RoPE pattern: [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, ...] + // Q/K weights are permuted during conversion from interleaved to split format. + // build_rope_2d expects split format and outputs split format. + // We need to convert the output back to interleaved format for the attention mechanism. auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - return build_rope_2d_interleaved(ctx0, cur, pos_w, pos_h, hparams.rope_theta); + const int64_t n_dim = cur->ne[0]; + const int64_t n_head = cur->ne[1]; + const int64_t n_pos = cur->ne[2]; + + // Apply RoPE in split format + cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + + // Convert output from split format back to interleaved format + // Split: [x0_re, x0_im, x1_re, x1_im, ..., y0_re, y0_im, y1_re, y1_im, ...] + // Interleaved: [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, y1_re, y1_im, ...] + // + // Reshape to [2, n_dim/4, 2, n_head, n_pos] where: + // - first dim 2 = re/im pair + // - n_dim/4 = number of frequency pairs per axis + // - second dim 2 = X half (0) vs Y half (1) + // Then permute to interleave X and Y + // Finally reshape back to [n_dim, n_head, n_pos] + cur = ggml_reshape_4d(ctx0, cur, 2, n_dim/4, 2, n_head * n_pos); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [2, 2, n_dim/4, n_head*n_pos] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_3d(ctx0, cur, n_dim, n_head, n_pos); + + return cur; }; ggml_tensor * inp = build_inp(); From 0c50dd9fe4e43e598edda17f37956df3ba2a377a Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 8 Feb 2026 01:19:18 -0800 Subject: [PATCH 09/12] Kimi-K2.5: support non-interleaved rope for vision --- convert_hf_to_gguf.py | 51 ++++++++++----------- gguf-py/gguf/tensor_mapping.py | 1 + tools/mtmd/clip-graph.h | 11 ----- tools/mtmd/clip.cpp | 82 ---------------------------------- tools/mtmd/models/kimik25.cpp | 30 +++---------- 5 files changed, 29 insertions(+), 146 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 070b22fcd9..83c2cd6923 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11106,8 +11106,8 @@ class KimiK25Model(MmprojModel): self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5)) self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0]) - # Image size limits (from preprocessor_config.json media_proc_cfg) - # These are used to set token limits: tokens = pixels / (patch_size²) + # Image size limits + # These are used to set token limits: tokens = pixels / (patch_size ^ 2) in_patch_limit = self.preprocessor_config.get("in_patch_limit_each_frame", self.preprocessor_config.get("in_patch_limit", 4096)) min_patches = 8 # reasonable minimum @@ -11116,31 +11116,19 @@ class KimiK25Model(MmprojModel): self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch) @staticmethod - def _permute_rope_interleaved_to_split(weights: Tensor, n_head: int) -> Tensor: - """Permute Q/K weights from interleaved to split RoPE format. - - Kimi-K2.5 uses interleaved 2D RoPE pattern (per head): - [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, y1_re, y1_im, ...] 
- i.e., groups of 4: (x_pair, y_pair) repeated - - llama.cpp build_rope_2d expects split format (per head): - [x0_re, x0_im, x1_re, x1_im, ..., y0_re, y0_im, y1_re, y1_im, ...] - i.e., first half is all X pairs, second half is all Y pairs - - This permutation is applied at conversion time so we can use build_rope_2d at runtime. - """ + def _permute_kqv(weights: Tensor, n_head: int) -> Tensor: out_dim, in_dim = weights.shape head_dim = out_dim // n_head - # Reshape to expose the interleaved structure: - # [n_head, head_dim//4, 2, 2, in_dim] - # where: head_dim//4 = number of (x,y) frequency pairs - # first 2 = x_or_y (0=x, 1=y) - # second 2 = re_or_im (real, imaginary parts of complex rotation) w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim) - # Permute to split format: [n_head, 2, head_dim//4, 2, in_dim] - # Now dim 1 separates X (index 0) from Y (index 1) w = w.permute(0, 2, 1, 3, 4) - # Reshape back: [out_dim, in_dim] + return w.reshape(out_dim, in_dim) + + @staticmethod + def _permute_output_proj(weights: Tensor, n_head: int) -> Tensor: + out_dim, in_dim = weights.shape + head_dim = in_dim // n_head + w = weights.reshape(out_dim, n_head, head_dim // 4, 2, 2) + w = w.permute(0, 1, 3, 2, 4) return w.reshape(out_dim, in_dim) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: @@ -11153,8 +11141,10 @@ class KimiK25Model(MmprojModel): assert self.hparams_vision is not None n_head = self.hparams_vision.get("num_attention_heads", 16) - # Permute Q/K weights/biases from interleaved to split RoPE format - # This allows using the build_rope_2d at runtime + # Permute Q/K/V weights/biases from interleaved to split RoPE format + # This allows using build_rope_2d at runtime without post-permutation. + # V is also permuted so the attention output is in split format, + # which is then handled by the permuted output projection. 
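+ # The fused wqkv tensor stacks Q, K and V along the output dimension, so it is split
+ # into equal thirds before the per-head permutation is applied.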
if "wqkv" in name: out_dim = data_torch.shape[0] qkv_dim = out_dim // 3 @@ -11162,16 +11152,21 @@ class KimiK25Model(MmprojModel): if "weight" in name: wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2*qkv_dim, :], data_torch[2*qkv_dim:, :] - wq = self._permute_rope_interleaved_to_split(wq, n_head) - wk = self._permute_rope_interleaved_to_split(wk, n_head) + wq = self._permute_kqv(wq, n_head) + wk = self._permute_kqv(wk, n_head) + wv = self._permute_kqv(wv, n_head) data_torch = torch.cat([wq, wk, wv], dim=0) elif "bias" in name: bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2*qkv_dim], data_torch[2*qkv_dim:] - # Same permutation as weights: [n_head, head_dim//4, 2, 2] -> [n_head, 2, head_dim//4, 2] bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + bv = bv.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) data_torch = torch.cat([bq, bk, bv], dim=0) + # Permute output projection from interleaved to split RoPE format + if "wo.weight" in name: + data_torch = self._permute_output_proj(data_torch, n_head) + # Temporal embeddings: (T, 1, C) → (T, C) if "pos_emb.time_weight" in name: T, _, C = data_torch.shape diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index ba4f644dc2..548b035964 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1358,6 +1358,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm + "vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5 ), MODEL_TENSOR.V_ENC_ATTN_Q: ( diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 8c9d56c8cb..4c7f7504cf 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -107,17 +107,6 @@ struct clip_graph { const bool interleave_freq ); - // 2D RoPE with interleaved frequency - // Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...] - // build_rope_2d uses split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...] - ggml_tensor * build_rope_2d_interleaved( - ggml_context * ctx0, - ggml_tensor * cur, // [n_dim, n_head, n_pos] - ggml_tensor * pos_w, // [n_pos] - X/width positions - ggml_tensor * pos_h, // [n_pos] - Y/height positions - const float freq_base - ); - // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) // support dynamic resolution ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index eb174e4b17..168341edf0 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -715,88 +715,6 @@ ggml_tensor * clip_graph::build_rope_2d( return cur; } -// 2D RoPE with interleaved frequency -// Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...] -// build_rope_2d uses split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...] 
-ggml_tensor * clip_graph::build_rope_2d_interleaved( - ggml_context * ctx0, - ggml_tensor * cur, // [n_dim, n_head, n_pos] - ggml_tensor * pos_w, // [n_pos] - X/width positions - ggml_tensor * pos_h, // [n_pos] - Y/height positions - const float freq_base -) { - const int64_t n_dim = cur->ne[0]; - const int64_t n_head = cur->ne[1]; - const int64_t n_pos = cur->ne[2]; - - GGML_ASSERT(n_dim % 4 == 0); // Must be divisible by 4 for interleaved x,y pairs - - // Ensure input is contiguous (needed when using merged QKV with ggml_view) - if (!ggml_is_contiguous(cur)) { - cur = ggml_cont(ctx0, cur); - } - - // Step 1: Reshape to expose interleaved structure - // cur: [n_dim, n_head, n_pos] -> [4, n_dim/4, n_head, n_pos] - ggml_tensor * reshaped = ggml_reshape_4d(ctx0, cur, 4, n_dim/4, n_head, n_pos); - - // Step 2: Extract X pairs (elements 0,1 of each group of 4) - // x_pairs: [2, n_dim/4, n_head, n_pos] - ggml_tensor * x_pairs = ggml_view_4d(ctx0, reshaped, - 2, n_dim/4, n_head, n_pos, - reshaped->nb[1], reshaped->nb[2], reshaped->nb[3], - 0); - - // Step 3: Extract Y pairs (elements 2,3 of each group of 4) - // y_pairs: [2, n_dim/4, n_head, n_pos] - ggml_tensor * y_pairs = ggml_view_4d(ctx0, reshaped, - 2, n_dim/4, n_head, n_pos, - reshaped->nb[1], reshaped->nb[2], reshaped->nb[3], - 2 * ggml_element_size(reshaped)); - - // Step 4: Make contiguous and reshape for rope_ext - // [2, n_dim/4, n_head, n_pos] -> [n_dim/2, n_head, n_pos] - x_pairs = ggml_cont(ctx0, x_pairs); - x_pairs = ggml_reshape_3d(ctx0, x_pairs, n_dim/2, n_head, n_pos); - - y_pairs = ggml_cont(ctx0, y_pairs); - y_pairs = ggml_reshape_3d(ctx0, y_pairs, n_dim/2, n_head, n_pos); - - // Step 5: Apply RoPE to X pairs using pos_w, Y pairs using pos_h - x_pairs = ggml_rope_ext( - ctx0, - x_pairs, - pos_w, - nullptr, - n_dim/2, - 0, 0, freq_base, - 1.0f, 0.0f, 1.0f, 0.0f, 0.0f - ); - - y_pairs = ggml_rope_ext( - ctx0, - y_pairs, - pos_h, - nullptr, - n_dim/2, - 0, 0, freq_base, - 1.0f, 0.0f, 1.0f, 0.0f, 0.0f - ); - - // Step 6: Reshape back to [2, n_dim/4, n_head, n_pos] for interleaving - x_pairs = ggml_reshape_4d(ctx0, x_pairs, 2, n_dim/4, n_head, n_pos); - y_pairs = ggml_reshape_4d(ctx0, y_pairs, 2, n_dim/4, n_head, n_pos); - - // Step 7: Interleave X and Y pairs back together - // Concatenate along dimension 0: [4, n_dim/4, n_head, n_pos] - ggml_tensor * result = ggml_concat(ctx0, x_pairs, y_pairs, 0); - - // Step 8: Reshape back to original: [n_dim, n_head, n_pos] - result = ggml_reshape_3d(ctx0, result, n_dim, n_head, n_pos); - - return result; -} - // Generic function to stack frames for audio processing // Abstracts out the StackAudioFrames logic used by ultravox ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) { diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp index ceb7b848f9..5f5cd9b7ed 100644 --- a/tools/mtmd/models/kimik25.cpp +++ b/tools/mtmd/models/kimik25.cpp @@ -42,33 +42,13 @@ ggml_cgraph * clip_graph_kimik25::build() { ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC); - // Kimi-K2.5 uses interleaved 2D RoPE pattern: [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, ...] - // Q/K weights are permuted during conversion from interleaved to split format. - // build_rope_2d expects split format and outputs split format. - // We need to convert the output back to interleaved format for the attention mechanism. 
+ // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but all attention weights + // (Q, K, V, O) are permuted during conversion to use split format throughout. + // This allows using build_rope_2d without any runtime format conversion. + // The dot product in attention is order-independent, so keeping everything in + // split format produces mathematically equivalent results. auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - const int64_t n_dim = cur->ne[0]; - const int64_t n_head = cur->ne[1]; - const int64_t n_pos = cur->ne[2]; - - // Apply RoPE in split format cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); - - // Convert output from split format back to interleaved format - // Split: [x0_re, x0_im, x1_re, x1_im, ..., y0_re, y0_im, y1_re, y1_im, ...] - // Interleaved: [x0_re, x0_im, y0_re, y0_im, x1_re, x1_im, y1_re, y1_im, ...] - // - // Reshape to [2, n_dim/4, 2, n_head, n_pos] where: - // - first dim 2 = re/im pair - // - n_dim/4 = number of frequency pairs per axis - // - second dim 2 = X half (0) vs Y half (1) - // Then permute to interleave X and Y - // Finally reshape back to [n_dim, n_head, n_pos] - cur = ggml_reshape_4d(ctx0, cur, 2, n_dim/4, 2, n_head * n_pos); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [2, 2, n_dim/4, n_head*n_pos] - cur = ggml_cont(ctx0, cur); - cur = ggml_reshape_3d(ctx0, cur, n_dim, n_head, n_pos); - return cur; }; From d0d1062e7f5ea21f722a4b85cbe99fa900e87e74 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 8 Feb 2026 02:15:12 -0800 Subject: [PATCH 10/12] Kimi-K2.5: fix min / max pixel --- convert_hf_to_gguf.py | 7 +++--- tools/mtmd/clip.cpp | 53 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 83c2cd6923..c58dd91d9d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11107,11 +11107,10 @@ class KimiK25Model(MmprojModel): self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0]) # Image size limits - # These are used to set token limits: tokens = pixels / (patch_size ^ 2) - in_patch_limit = self.preprocessor_config.get("in_patch_limit_each_frame", - self.preprocessor_config.get("in_patch_limit", 4096)) + # Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet) + in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384) min_patches = 8 # reasonable minimum - pixels_per_patch = self.patch_size * self.patch_size + pixels_per_patch = self.patch_size ** 2 self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch) self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 168341edf0..dae17c6fb0 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1153,20 +1153,16 @@ struct clip_model_loader { hparams.rope_theta = 10000.0f; get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); - // Read min/max pixels from GGUF and convert to token limits int min_pixels = 0, max_pixels = 0; get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false); get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false); if (min_pixels > 0 && max_pixels > 0) { - const int pixels_per_patch = hparams.patch_size * hparams.patch_size; - const int min_tokens = min_pixels / pixels_per_patch; - const int max_tokens = max_pixels / pixels_per_patch; - hparams.set_limit_image_tokens(min_tokens, max_tokens); + hparams.image_min_pixels = 
min_pixels; + hparams.image_max_pixels = max_pixels; + hparams.warmup_image_size = static_cast(std::sqrt(max_pixels)); } else { - // Fallback to hardcoded defaults - hparams.set_limit_image_tokens(8, 4096); + hparams.set_limit_image_tokens(2, 4096); } - hparams.set_warmup_n_tokens(256); } break; case PROJECTOR_TYPE_GEMMA3: { @@ -3773,6 +3769,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); } + // Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set + if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) { + const int64_t n_embd = embeddings->ne[0]; + const int64_t n_tokens = embeddings->ne[1]; + std::vector emb_data(n_embd * n_tokens); + ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings)); + + LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n"); + LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens); + + // Print first few values of first token + LOG_INF("Token 0 (first 16 values): "); + for (int i = 0; i < std::min((int64_t)16, n_embd); i++) { + LOG_INF("%.6f ", emb_data[i]); + } + LOG_INF("\n"); + + // Print last few values of first token + if (n_embd > 16) { + LOG_INF("Token 0 (last 16 values): "); + for (int64_t i = n_embd - 16; i < n_embd; i++) { + LOG_INF("%.6f ", emb_data[i]); + } + LOG_INF("\n"); + } + + // Compute and print statistics + float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0]; + for (size_t i = 0; i < emb_data.size(); i++) { + sum += emb_data[i]; + sum_sq += emb_data[i] * emb_data[i]; + min_val = std::min(min_val, emb_data[i]); + max_val = std::max(max_val, emb_data[i]); + } + float mean = sum / emb_data.size(); + float variance = (sum_sq / emb_data.size()) - (mean * mean); + LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n", + mean, sqrtf(variance), min_val, max_val, sum); + LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n"); + } + return true; } From c8953657c474b2ad52bafe77ebac6ca9de606df9 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 8 Feb 2026 12:18:49 -0800 Subject: [PATCH 11/12] Kimi-K2.5: remove v/o permutes, unnecessary --- convert_hf_to_gguf.py | 17 +---------------- tools/mtmd/models/kimik25.cpp | 7 ++----- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c58dd91d9d..5a3f74812e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11122,14 +11122,6 @@ class KimiK25Model(MmprojModel): w = w.permute(0, 2, 1, 3, 4) return w.reshape(out_dim, in_dim) - @staticmethod - def _permute_output_proj(weights: Tensor, n_head: int) -> Tensor: - out_dim, in_dim = weights.shape - head_dim = in_dim // n_head - w = weights.reshape(out_dim, n_head, head_dim // 4, 2, 2) - w = w.permute(0, 1, 3, 2, 4) - return w.reshape(out_dim, in_dim) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Only process vision and projector tensors is_vision = any(x in name for x in ["vision_tower", "mm_projector"]) @@ -11140,10 +11132,8 @@ class KimiK25Model(MmprojModel): assert self.hparams_vision is not None n_head = self.hparams_vision.get("num_attention_heads", 16) - # Permute Q/K/V weights/biases from interleaved to split RoPE format + # Permute Q/K weights/biases from interleaved to split RoPE format # This allows using build_rope_2d at runtime without post-permutation. 
- # V is also permuted so the attention output is in split format, - # which is then handled by the permuted output projection. if "wqkv" in name: out_dim = data_torch.shape[0] qkv_dim = out_dim // 3 @@ -11153,18 +11143,13 @@ class KimiK25Model(MmprojModel): wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2*qkv_dim, :], data_torch[2*qkv_dim:, :] wq = self._permute_kqv(wq, n_head) wk = self._permute_kqv(wk, n_head) - wv = self._permute_kqv(wv, n_head) data_torch = torch.cat([wq, wk, wv], dim=0) elif "bias" in name: bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2*qkv_dim], data_torch[2*qkv_dim:] bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) - bv = bv.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) data_torch = torch.cat([bq, bk, bv], dim=0) - # Permute output projection from interleaved to split RoPE format - if "wo.weight" in name: - data_torch = self._permute_output_proj(data_torch, n_head) # Temporal embeddings: (T, 1, C) → (T, C) if "pos_emb.time_weight" in name: diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp index 5f5cd9b7ed..cf9f27f63a 100644 --- a/tools/mtmd/models/kimik25.cpp +++ b/tools/mtmd/models/kimik25.cpp @@ -42,11 +42,8 @@ ggml_cgraph * clip_graph_kimik25::build() { ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC); - // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but all attention weights - // (Q, K, V, O) are permuted during conversion to use split format throughout. - // This allows using build_rope_2d without any runtime format conversion. - // The dot product in attention is order-independent, so keeping everything in - // split format produces mathematically equivalent results. + // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but + // Q / K are permuted during conversion to use split format. auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); return cur; From 7b4af22fa99b85153d29ea573ff012992d22afde Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Sun, 8 Feb 2026 12:37:02 -0800 Subject: [PATCH 12/12] Kimi-K2.5: update permute name to match --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5a3f74812e..e138c2eebd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11115,7 +11115,7 @@ class KimiK25Model(MmprojModel): self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch) @staticmethod - def _permute_kqv(weights: Tensor, n_head: int) -> Tensor: + def permute(weights: Tensor, n_head: int) -> Tensor: out_dim, in_dim = weights.shape head_dim = out_dim // n_head w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim) @@ -11141,8 +11141,8 @@ class KimiK25Model(MmprojModel): if "weight" in name: wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2*qkv_dim, :], data_torch[2*qkv_dim:, :] - wq = self._permute_kqv(wq, n_head) - wk = self._permute_kqv(wk, n_head) + wq = self.permute(wq, n_head) + wk = self.permute(wk, n_head) data_torch = torch.cat([wq, wk, wv], dim=0) elif "bias" in name: bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2*qkv_dim], data_torch[2*qkv_dim:]