From 042c3cb8c5cdd857e55300630f405bb5041afd4b Mon Sep 17 00:00:00 2001
From: Aes Sedai <7980540+AesSedai@users.noreply.github.com>
Date: Wed, 28 Jan 2026 22:06:59 -0800
Subject: [PATCH] Move dequant_model to after the text_config merge

Add new Kimi-K2.5 keys to the mtmd convert path
Update the V_MMPROJ tensor mapping for the new mm_projector.proj keys
Update V_MM_INP_NORM for the new mm_projector.pre_norm key
---
 convert_hf_to_gguf.py          | 41 +++++++++++++++++++++++++++------------
 gguf-py/gguf/tensor_mapping.py |  2 ++
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index eb43520f98..8e293cd970 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -160,8 +160,6 @@ class ModelBase:
                 self.ftype = gguf.LlamaFileType.MOSTLY_F16
                 logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
 
-        self.dequant_model()
-
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -527,6 +525,8 @@ class ModelBase:
         return ()
 
     def prepare_tensors(self):
+        self.dequant_model()
+
         # Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
         if self.tensor_map.mapping:
             max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -1808,7 +1808,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
 
     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -1863,7 +1863,15 @@ class MmprojModel(ModelBase):
         preprocessor_config_path = self.dir_model / "preprocessor_config.json"
         if preprocessor_config_path.is_file():
             with open(preprocessor_config_path, "r", encoding="utf-8") as f:
-                self.preprocessor_config = json.load(f)
+                cfg = json.load(f)
+                # move media_proc_cfg to root level for compat
+                if "media_proc_cfg" in cfg:
+                    cfg = {
+                        **cfg,
+                        **cfg["media_proc_cfg"],
+                    }
+                # merge configs
+                self.preprocessor_config = {**self.preprocessor_config, **cfg}
 
         # prefer processor_config.json if possible
         processor_config_path = self.dir_model / "processor_config.json"
@@ -1912,10 +1920,10 @@ class MmprojModel(ModelBase):
             self.image_size = self.find_vparam(["image_size"])
             self.gguf_writer.add_vision_image_size(self.image_size)
             self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
-            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
-            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
+            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
+            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
             self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
-            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
+            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
 
             # preprocessor config
             image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -7360,6 +7368,7 @@ class DeepseekModel(TextModel):
     "DeepseekV2ForCausalLM",
     "DeepseekV3ForCausalLM",
     "KimiVLForConditionalGeneration",
+    "KimiK25ForConditionalGeneration",
     "YoutuForCausalLM",
     "YoutuVLForConditionalGeneration",
 )
@@ -7478,8 +7487,8 @@ class DeepseekV2Model(TextModel):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # skip vision tensors and remove "language_model." for Kimi-VL
-        if "vision_tower" in name or "multi_modal_projector" in name:
+        # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
+        if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
             return
         if name.startswith("siglip2.") or name.startswith("merger."):
             return
@@ -10614,7 +10623,7 @@ class MistralMoeModel(DeepseekV2Model):
         self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1)  # mscale_all_dim * 0.1
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
-        if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
+        if name.startswith("vision_") or name.startswith("patch_merger."):
             return
 
         # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic
@@ -10679,7 +10688,7 @@ class LightOnOCRVisionModel(LlavaVisionModel):
         yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register("KimiVLForConditionalGeneration", "KimiK25ForConditionalGeneration")
 class KimiVLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -10696,9 +10705,17 @@ class KimiVLModel(MmprojModel):
         self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name
 
         if is_vision_tensor:
+            # update names:
+            #   "mm_projector.proj.0." -> "mm_projector.proj.linear_1."
+            #   "mm_projector.proj.2." -> "mm_projector.proj.linear_2."
+            if "proj.0." in name:
+                name = name.replace(".0.", ".linear_1.")
+            if "proj.2." in name:
+                name = name.replace(".2.", ".linear_2.")
+
             if "pos_emb.weight" in name:
                 data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 84aa868809..456b3640c9 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1255,6 +1255,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "mm_projector.proj.linear_{bid}",  # Kimi-K2.5
             "visual.merger.mlp.{bid}",  # qwen2vl
             "merger.mlp.{bid}",
         ),
@@ -1490,6 +1491,7 @@ class TensorNameMap:
             "multi_modal_projector.norm",
             "multi_modal_projector.layer_norm",
             "multi_modal_projector.pre_norm",
+            "mm_projector.pre_norm",  # Kimi-K2.5
             "pre_mm_projector_norm",
             "model.vision.linear_proj.norm1",  # cogvlm
             "merger.ln_q",
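
Note for reviewers: the preprocessor_config hunk hoists Kimi-K2.5's nested
media_proc_cfg block to the root level before merging, so nested keys such as
image_mean become visible to the generic lookup code. Below is a minimal
standalone sketch of the merge order; the config contents are hypothetical
and only the dict-splat logic mirrors the patch:

    # later splats win: hoisted media_proc_cfg keys override same-named root
    # keys, and the file's keys override anything loaded earlier
    defaults = {"image_mean": [0.5, 0.5, 0.5]}  # hypothetical pre-loaded values

    cfg = {
        "image_mean": [0.0, 0.0, 0.0],  # root-level key, overridden by the hoist
        "media_proc_cfg": {"image_mean": [0.48, 0.46, 0.41]},
    }

    # move media_proc_cfg to root level, as in the patch
    if "media_proc_cfg" in cfg:
        cfg = {**cfg, **cfg["media_proc_cfg"]}

    # merge with the previously loaded config
    merged = {**defaults, **cfg}
    assert merged["image_mean"] == [0.48, 0.46, 0.41]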
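
Similarly, the vt_-prefixed fallbacks added to n_block_keys and the
find_vparam calls assume Kimi-K2.5's vision config names its hyperparameters
vt_hidden_size, vt_num_hidden_layers, and so on. The converter's real
find_vparam is a method on MmprojModel that searches the vision hparams; this
standalone version and the sample config are illustrative only, showing the
first-match-wins order the extended key lists rely on:

    from typing import Any

    def find_vparam(vcfg: dict[str, Any], keys: list[str]) -> Any:
        # return the value of the first key present, in list order
        for k in keys:
            if k in vcfg:
                return vcfg[k]
        raise KeyError(f"none of {keys} found in vision config")

    vcfg = {"vt_hidden_size": 1152, "vt_num_attention_heads": 16}  # hypothetical
    assert find_vparam(vcfg, ["hidden_size", "vt_hidden_size"]) == 1152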
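
Finally, the rename in KimiVLModel.modify_tensors lines the checkpoint's
nn.Sequential indices up with the {bid} placeholder in the new V_MMPROJ entry.
A sketch with hypothetical tensor names; a Linear/activation/Linear projector
is assumed, so index 1 presumably carries no weights:

    def rename(name: str) -> str:
        # mirror the patch: sequential indices 0/2 -> linear_1/linear_2
        if "proj.0." in name:
            name = name.replace(".0.", ".linear_1.")
        if "proj.2." in name:
            name = name.replace(".2.", ".linear_2.")
        return name

    assert rename("mm_projector.proj.0.weight") == "mm_projector.proj.linear_1.weight"
    assert rename("mm_projector.proj.2.bias") == "mm_projector.proj.linear_2.bias"
    # after the rename, "mm_projector.proj.linear_{bid}" matches with bid in {1, 2}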