From 6e89a8b296cf880d0c5e869962b6a482f99fb85f Mon Sep 17 00:00:00 2001 From: liyang Date: Tue, 18 Nov 2025 09:55:23 +0800 Subject: [PATCH] Refactor JinaCLIP vision mmproj mapping to use tensor_mapping table --- common/arg.cpp | 1 - convert_hf_to_gguf.py | 272 ++++++++++---------------------- gguf-py/gguf/constants.py | 12 ++ gguf-py/gguf/tensor_mapping.py | 28 ++++ tools/mtmd/clip-graph.h | 1 + tools/mtmd/clip-impl.h | 8 +- tools/mtmd/clip-model.h | 5 + tools/mtmd/clip.cpp | 39 +++-- tools/mtmd/clip.h | 1 - tools/mtmd/models/jinaclip2.cpp | 12 +- tools/mtmd/mtmd-cli.cpp | 163 ++++++++++--------- tools/mtmd/mtmd-helper.cpp | 1 - tools/mtmd/mtmd-helper.h | 3 +- tools/mtmd/mtmd.cpp | 61 ++----- tools/mtmd/mtmd.h | 31 ++-- 15 files changed, 279 insertions(+), 359 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index bffe88d81a..7767c08e10 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2747,7 +2747,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index fbf30b81b3..8dc7b17a0a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1808,7 +1808,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + n_block_keys = ["layers", "n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1830,7 +1830,13 @@ class MmprojModel(ModelBase): if "audio_config" not in self.hparams: self.hparams["audio_config"] = {} text_config = {**self.hparams, **self.hparams["text_config"]} - self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) + n_embd_text = ( + text_config.get("hidden_size") + or text_config.get("n_embd") + or text_config.get("embed_dim") + or 0 + ) + self.n_embd_text = int(n_embd_text) if n_embd_text else 0 else: text_config = { k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"] @@ -5635,13 +5641,6 @@ class XLMRobertaModel(BertModel): if lora_names := hparams.get("lora_adaptations"): self._lora_names = lora_names - - pe_type = (hparams.get("position_embedding_type") or "").lower() - rope_base = hparams.get("rotary_emb_base") - name_path = (hparams.get("_name_or_path") or "").lower() - is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path)) - is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx - if is_v3 or self._lora_names: self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3 super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs) @@ -7056,39 +7055,76 @@ class JinaBertV2Model(BertModel): raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') -@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel") -class JinaCLIPVisionModel(MmprojModel): - """JinaCLIP v2 Vision Encoder Model - handles vision component only""" - model_arch = gguf.MODEL_ARCH.MMPROJ +@ModelBase.register("JinaCLIPModel") +class JinaCLIPTextModel(XLMRobertaModel): + model_arch = gguf.MODEL_ARCH.BERT + _text_prefix = "text_model.transformer." - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + @staticmethod + def _load_json_file(path: Path) -> dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) - # Load config for vision encoder - config_path = self.dir_model / "config.json" - if not config_path.exists(): - raise FileNotFoundError( - f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. " - "Please ensure the original model config is present; default hyperparameter fallbacks are not used." - ) - with open(config_path, encoding="utf-8") as f: - self.vision_config = json.load(f) - - def set_vocab(self): - # Vision encoder doesn't need vocabulary - pass - - def set_gguf_parameters(self): - cfg = self.vision_config + @staticmethod + def _load_hf_config_json(hf_name_or_path: str) -> dict[str, Any]: + p = Path(hf_name_or_path) + if p.is_dir(): + cfg_path = p / "config.json" + if cfg_path.is_file(): + return JinaCLIPTextModel._load_json_file(cfg_path) try: - width = int(cfg["width"]) # channel dim - head_width = int(cfg["head_width"]) # per-head dim - layers = int(cfg["layers"]) # block count - image_size = int(cfg["image_size"]) # input image size - patch_size = int(cfg["patch_size"]) # patch size - except KeyError as e: - raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}") + from huggingface_hub import hf_hub_download + except Exception: + raise ImportError( + "huggingface_hub is required to fetch the text tower config.json for JinaClip; " + "install this package or provide a local path in text_config.hf_model_name_or_path." + ) + + try: + cfg_path = Path(hf_hub_download(repo_id=hf_name_or_path, filename="config.json", local_files_only=True)) + except Exception: + cfg_path = Path(hf_hub_download(repo_id=hf_name_or_path, filename="config.json", local_files_only=False)) + return JinaCLIPTextModel._load_json_file(cfg_path) + + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): + jinaclip_hparams = ModelBase.load_hparams(dir_model, False) + text_cfg = jinaclip_hparams.get("text_config") or {} + hf_name = text_cfg.get("hf_model_name_or_path") + if not hf_name: + raise KeyError("JinaCLIPTextModel: missing text_config.hf_model_name_or_path in config.json") + + base_cfg = self._load_hf_config_json(str(hf_name)) + + overrides = text_cfg.get("hf_model_config_kwargs") or {} + if not isinstance(overrides, dict): + raise TypeError("JinaCLIPTextModel: text_config.hf_model_config_kwargs must be a dict") + + merged_hparams = {**base_cfg, **overrides} + + kwargs["hparams"] = merged_hparams + + super().__init__(dir_model, ftype, fname_out, **kwargs) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if not name.startswith(self._text_prefix): + return [] + + name = name[len(self._text_prefix):] + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("JinaCLIPModel") +class JinaCLIPVisionModel(MmprojModel): + + def set_gguf_parameters(self): + cfg = self.hparams + + width = int(self.find_hparam(["width"])) + head_width = int(self.find_hparam(["head_width"])) + layers = int(self.find_hparam(["layers"])) + image_size = int(self.find_hparam(["image_size"])) + patch_size = int(self.find_hparam(["patch_size"])) if width % head_width != 0: raise ValueError( @@ -7103,172 +7139,38 @@ class JinaCLIPVisionModel(MmprojModel): else: raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json") + self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_clip_has_vision_encoder(True) - proj_dim = int(cfg.get("projection_dim", width)) + proj_dim = int(self.global_config.get("projection_dim") or cfg.get("embed_dim") or width) self.gguf_writer.add_vision_projection_dim(proj_dim) self.gguf_writer.add_vision_image_size(image_size) self.gguf_writer.add_vision_patch_size(patch_size) self.gguf_writer.add_vision_embedding_length(width) + self.gguf_writer.add_vision_feed_forward_length(n_ff) self.gguf_writer.add_vision_block_count(layers) self.gguf_writer.add_vision_head_count(n_head) - self.gguf_writer.add_vision_feed_forward_length(n_ff) - self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5))) + self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-6))) - mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean")) - std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std")) - if mean is None or std is None: - raise KeyError( - "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')" - ) + # JinaClip v2 uses mean/std in preprocessor_config.json + mean = self.preprocessor_config["mean"] + std = self.preprocessor_config["std"] self.gguf_writer.add_vision_image_mean(mean) self.gguf_writer.add_vision_image_std(std) self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2) self.gguf_writer.add_vision_use_silu(True) - def _strip_vm_prefix(self, name: str) -> str: - return name[len('vision_model.'):] if name.startswith('vision_model.') else name - - def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None: - parts = rest.split('.') - # layer norms - if rest.startswith('norm1.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)] - if rest.startswith('norm2.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)] - if rest.startswith('attn.inner_attn_ln.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)] - - if rest == 'attn.q_bias': - return [(f'v.blk.{layer}.attn_q.bias', data_torch)] - if rest == 'attn.v_bias': - return [(f'v.blk.{layer}.attn_v.bias', data_torch)] - - if rest.startswith('attn.q_proj.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)] - if rest.startswith('attn.k_proj.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)] - if rest.startswith('attn.v_proj.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)] - if rest.startswith('attn.proj.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)] - - # MLP - if rest.startswith('mlp.w1.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)] - if rest.startswith('mlp.w2.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)] - if rest.startswith('mlp.w3.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)] - if rest.startswith('mlp.ffn_ln.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)] - if rest.startswith('mlp.fc1.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)] - if rest.startswith('mlp.fc2.'): - suffix = parts[-1] - return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)] - return None - - def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: - """Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper.""" - # Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is - if name.startswith('v.') or name.startswith('mm.'): - return name - # Try the base mapping first - try: - return super().map_tensor_name(name, try_suffixes=try_suffixes) - except Exception: - # Fallback to legacy Jina-specific mapper for any remaining edge keys - if hasattr(self, "_map_jinaclip_tensor_name"): - mapped = self._map_jinaclip_tensor_name(name) # type: ignore[attr-defined] - if mapped: - return mapped - return name - - def get_tensors(self) -> Iterator[tuple[str, Tensor]]: - yielded_any = False - try: - for name, tensor in super().get_tensors(): - yielded_any = True - yield name, tensor - except Exception as e: - logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e) - if yielded_any: - return - - candidates = [ - self.dir_model / "pytorch_model.bin", - self.dir_model / "vision_model_weights.bin", - ] - model_path = next((p for p in candidates if p.exists()), None) - if model_path is None: - raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}") - try: - state_dict = torch.load(model_path, map_location="cpu", weights_only=True) - except TypeError: - state_dict = torch.load(model_path, map_location="cpu") - - for name, tensor in state_dict.items(): - yield name, tensor - - def _should_be_f32(self, gguf_name: str) -> bool: - patterns = ( - ".ln1.weight", ".ln1.bias", - ".ln2.weight", ".ln2.bias", - ".attn_ln.weight", ".attn_ln.bias", - ".ffn_norm.weight", ".ffn_norm.bias", - "v.patch_embd.proj.bias", - ) - return any(p in gguf_name for p in patterns) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused + if name.startswith("vision_model."): + name = name[len("vision_model."):] + elif not (name.startswith("v.") or name.startswith("mm.")): + return [] - src = name - if src.startswith('v.') or src.startswith('mm.'): - return [(src, data_torch)] - - # Drop 'vision_model.' prefix if present - src_no_vm = self._strip_vm_prefix(src) - - # Top-level direct mappings — use gguf constants directly for canonical names - if src_no_vm == 'cls_token': - base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS] - return [(base, data_torch)] - if src_no_vm.startswith('patch_embed.proj.'): - suffix = src_no_vm.split('.')[-1] - base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] - return [(f'{base}.{suffix}', data_torch)] - if src_no_vm == 'pos_embed': - pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight' + if name == "pos_embed": + pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + ".weight" return [(pos_name, data_torch)] - if src_no_vm.startswith('norm.'): - suffix = src_no_vm.split('.')[-1] - base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM] - return [(f'{base}.{suffix}', data_torch)] - - if src_no_vm.startswith('blocks.'): - parts = src_no_vm.split('.') - if len(parts) >= 3 and parts[1].isdigit(): - layer = int(parts[1]) - rest = '.'.join(parts[2:]) - mapped = self._map_block_tensor(layer, rest, data_torch, name) - if mapped is not None: - return mapped try: return [(self.map_tensor_name(name), data_torch)] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8d91a47797..86754346b8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -668,9 +668,13 @@ class MODEL_TENSOR(IntEnum): V_ENC_ATTN_O = auto() V_ENC_ATTN_O_NORM = auto() V_ENC_POST_ATTN_NORM = auto() + V_ENC_ATTN_LN = auto() V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() V_ENC_FFN_DOWN = auto() + V_ENC_FFN_NORM = auto() + V_ENC_ATTN_Q_BIAS = auto() + V_ENC_ATTN_V_BIAS = auto() V_LAYER_SCALE_1 = auto() V_LAYER_SCALE_2 = auto() V_PRE_NORM = auto() @@ -1086,9 +1090,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", + MODEL_TENSOR.V_ENC_ATTN_LN: "v.blk.{bid}.attn_ln", MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_ENC_FFN_NORM: "v.blk.{bid}.ffn_norm", + MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: "v.blk.{bid}.attn_q.bias", + MODEL_TENSOR.V_ENC_ATTN_V_BIAS: "v.blk.{bid}.attn_v.bias", MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", @@ -1204,9 +1212,13 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_ATTN_O, MODEL_TENSOR.V_ENC_ATTN_O_NORM, MODEL_TENSOR.V_ENC_POST_ATTN_NORM, + MODEL_TENSOR.V_ENC_ATTN_LN, MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_ENC_FFN_NORM, + MODEL_TENSOR.V_ENC_ATTN_Q_BIAS, + MODEL_TENSOR.V_ENC_ATTN_V_BIAS, MODEL_TENSOR.V_LAYER_SCALE_1, MODEL_TENSOR.V_LAYER_SCALE_2, MODEL_TENSOR.V_PRE_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 84aa868809..684255495d 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1281,6 +1281,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "cls_token", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1295,6 +1296,7 @@ class TensorNameMap: "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm "siglip2.vision_model.embeddings.patch_embedding", + "patch_embed.proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1329,6 +1331,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl + "blocks.{bid}.attn.q_proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1347,6 +1350,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "blocks.{bid}.attn.k_proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1365,6 +1369,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "blocks.{bid}.attn.v_proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1380,6 +1385,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", + "blocks.{bid}.norm1", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1396,6 +1402,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl + "blocks.{bid}.attn.proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1411,6 +1418,11 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", + "blocks.{bid}.norm2", # JinaCLIP v2 vision + ), + + MODEL_TENSOR.V_ENC_ATTN_LN: ( + "blocks.{bid}.attn.inner_attn_ln", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1427,12 +1439,14 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", + "blocks.{bid}.mlp.w2", # JinaCLIP v2 vision (up) ), MODEL_TENSOR.V_ENC_FFN_GATE: ( "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf "vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl + "blocks.{bid}.mlp.w1", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_FFN_DOWN: ( @@ -1449,6 +1463,11 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", + "blocks.{bid}.mlp.w3", # JinaCLIP v2 vision (down) + ), + + MODEL_TENSOR.V_ENC_FFN_NORM: ( + "blocks.{bid}.mlp.ffn_ln", # JinaCLIP v2 vision ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1461,6 +1480,14 @@ class TensorNameMap: "model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1 ), + MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: ( + "blocks.{bid}.attn.q_bias", # JinaCLIP v2 vision + ), + + MODEL_TENSOR.V_ENC_ATTN_V_BIAS: ( + "blocks.{bid}.attn.v_bias", # JinaCLIP v2 vision + ), + MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", "vision_tower.ln_pre", # pixtral-hf @@ -1474,6 +1501,7 @@ class TensorNameMap: "vision_model.layernorm_post", # llama4 "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl + "norm", # JinaCLIP v2 vision "visual.post_layernorm", # glm4v "siglip2.vision_model.post_layernorm", ), diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 4c7f7504cf..62145d6dc9 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -31,6 +31,7 @@ struct clip_graph { const float eps; const float kq_scale; const clip_flash_attn_type flash_attn_type; + norm_type block_norm_t = NORM_TYPE_NORMAL; ggml_context_ptr ctx0_ptr; ggml_context * ctx0; diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 6da23e120b..ea016b2e57 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -76,15 +76,15 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" -#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s" // inner attention LayerNorm +#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s" #define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" #define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" -#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s" -#define TN_LN_1 "%s.blk.%d.ln1.%s" -#define TN_LN_2 "%s.blk.%d.ln2.%s" +#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" +#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm +#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm #define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale #define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale #define TN_LN_PRE "%s.pre_ln.%s" diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index d4ff9151bb..4c01a8ad54 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -117,6 +117,9 @@ struct clip_layer { ggml_tensor * k_norm = nullptr; ggml_tensor * q_norm = nullptr; + ggml_tensor * attn_out_norm_w = nullptr; + ggml_tensor * attn_out_norm_b = nullptr; + // layernorm 1 ggml_tensor * ln_1_w = nullptr; ggml_tensor * ln_1_b = nullptr; @@ -125,6 +128,8 @@ struct clip_layer { ggml_tensor * ff_up_b = nullptr; ggml_tensor * ff_gate_w = nullptr; ggml_tensor * ff_gate_b = nullptr; + ggml_tensor * ffn_hidden_norm_w = nullptr; + ggml_tensor * ffn_hidden_norm_b = nullptr; ggml_tensor * ff_down_w = nullptr; ggml_tensor * ff_down_b = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 25a93c050c..c9d90bc82f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -292,6 +292,8 @@ ggml_tensor * clip_graph::build_vit( ggml_tensor * learned_pos_embd, std::function add_pos ) { + block_norm_t = norm_t; + if (learned_pos_embd) { inp = ggml_add(ctx0, inp, learned_pos_embd); cb(inp, "pos_embed", -1); @@ -559,6 +561,14 @@ ggml_tensor * clip_graph::build_ffn( } break; } + if (il >= 0 && il < (int) model.layers.size()) { + const auto & layer = model.layers[il]; + if (layer.ffn_hidden_norm_w) { + cur = build_norm(cur, layer.ffn_hidden_norm_w, layer.ffn_hidden_norm_b, block_norm_t, eps, il); + cb(cur, "ffn_hidden_normed", il); + } + } + if (down) { cur = ggml_mul_mat(ctx0, down, cur); } @@ -628,6 +638,14 @@ ggml_tensor * clip_graph::build_attn( cb(cur, "kqv_out", il); + if (il >= 0 && il < (int) model.layers.size()) { + const auto & layer = model.layers[il]; + if (layer.attn_out_norm_w) { + cur = build_norm(cur, layer.attn_out_norm_w, layer.attn_out_norm_b, block_norm_t, eps, il); + cb(cur, "kqv_out_normed", il); + } + } + if (wo) { cur = ggml_mul_mat(ctx0, wo, cur); } @@ -1204,7 +1222,7 @@ struct clip_model_loader { case PROJECTOR_TYPE_JINACLIP2: { hparams.rope_theta = 10000.0f; - get_f32(KEY_VISION_ROPE_THETA, hparams.rope_theta, /*required=*/false); + get_f32(KEY_VISION_ROPE_THETA, hparams.rope_theta, false); } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: @@ -1364,6 +1382,7 @@ struct clip_model_loader { layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false); layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false); layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false); + layer.attn_out_norm_w = get_tensor(string_format(TN_ATTN_LN, prefix, il, "weight"), false); layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false); layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false); layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias @@ -1374,6 +1393,7 @@ struct clip_model_loader { layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false); layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false); layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false); + layer.attn_out_norm_b = get_tensor(string_format(TN_ATTN_LN, prefix, il, "bias"), false); layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false); layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false); @@ -1382,6 +1402,8 @@ struct clip_model_loader { layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false); layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false); layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false); + layer.ffn_hidden_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"), false); + layer.ffn_hidden_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"), false); layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); @@ -1793,7 +1815,6 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_JINACLIP2: { - // JinaCLIP2 is a pure vision encoder without additional projection layers. } break; case PROJECTOR_TYPE_LFM2A: { @@ -3035,7 +3056,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_u8 processed_image; const int sz = params.image_size; - // 1) Preserve aspect ratio: resize so that the shorter side == sz (bicubic). const int in_w = img->nx; const int in_h = img->ny; if (in_w <= 0 || in_h <= 0) { @@ -3055,14 +3075,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_u8 resized_keep_ratio; img_tool::resize(*img, resized_keep_ratio, clip_image_size{out_w, out_h}, img_tool::RESIZE_ALGO_BICUBIC); - // 2) Center-crop to sz x sz. const int x0 = std::max(0, (resized_keep_ratio.nx - sz) / 2); const int y0 = std::max(0, (resized_keep_ratio.ny - sz) / 2); const int crop_w = std::min(sz, resized_keep_ratio.nx); const int crop_h = std::min(sz, resized_keep_ratio.ny); img_tool::crop(resized_keep_ratio, processed_image, x0, y0, crop_w, crop_h); - // 3) Normalize. clip_image_f32_ptr img_f32(clip_image_f32_init()); normalize_image_u8_to_f32(processed_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); @@ -3699,19 +3717,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int half_local = d_head_local / 2; std::vector rope_c_first(half_local); std::vector rope_c_second(half_local); - const float odd = std::pow(hparams.rope_theta, (float) -2.0f / (float) d_head_local); for (int k = 0; k < half_local; ++k) { rope_c_first[k] = 1.0f / s; - rope_c_second[k] = 1.0f / (s * odd); + rope_c_second[k] = 1.0f / s; } - ggml_tensor * t1 = ggml_graph_get_tensor(gf, "rope_c_first"); - ggml_tensor * t2 = ggml_graph_get_tensor(gf, "rope_c_second"); - GGML_ASSERT(t1 && (t1->flags & GGML_TENSOR_FLAG_INPUT)); - GGML_ASSERT(t2 && (t2->flags & GGML_TENSOR_FLAG_INPUT)); - ggml_backend_tensor_set(t1, rope_c_first.data(), 0, ggml_nbytes(t1)); - ggml_backend_tensor_set(t2, rope_c_second.data(), 0, ggml_nbytes(t2)); + set_input_f32("rope_c_first", rope_c_first); + set_input_f32("rope_c_second", rope_c_second); } break; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 53f77b43af..71b58484d6 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -111,7 +111,6 @@ bool clip_is_llava(const struct clip_ctx * ctx); // note for contributor: this clip_is_(model) pattern is deprecated // do NOT add new functions like this - bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); // use by audio input diff --git a/tools/mtmd/models/jinaclip2.cpp b/tools/mtmd/models/jinaclip2.cpp index 47527233fe..7086a7cfbe 100644 --- a/tools/mtmd/models/jinaclip2.cpp +++ b/tools/mtmd/models/jinaclip2.cpp @@ -10,12 +10,10 @@ ggml_cgraph * clip_graph_jinaclip2::build() { GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported"); - // input for learned position embeddings ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); ggml_set_name(positions, "positions"); ggml_set_input(positions); - // inputs for 2D RoPE positions (includes CLS at index 0) ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); ggml_set_name(pos_h, "pos_h"); ggml_set_input(pos_h); @@ -24,7 +22,6 @@ ggml_cgraph * clip_graph_jinaclip2::build() { ggml_set_name(pos_w, "pos_w"); ggml_set_input(pos_w); - // frequency scaling factors for the 2D RoPE halves GGML_ASSERT(d_head % 2 == 0); ggml_tensor * rope_c_first = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head / 2); ggml_set_name(rope_c_first, "rope_c_first"); @@ -41,7 +38,7 @@ ggml_cgraph * clip_graph_jinaclip2::build() { inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions)); auto apply_rope_2d = [&](ggml_tensor * cur) -> ggml_tensor * { - // cur is [d_head, n_head, n_pos]; convert to [d_head, n_pos, n_head] for convenient slicing + ggml_tensor * cur_in = ggml_permute(ctx0, cur, 0, 2, 1, 3); const int64_t n_dim = cur_in->ne[0]; @@ -64,11 +61,10 @@ ggml_cgraph * clip_graph_jinaclip2::build() { pos_offset = 1; } - // select positions for patch tokens + // select positions ggml_tensor * pos_a = ggml_view_1d(ctx0, pos_h, n_pos_patches, pos_offset * (int64_t) ggml_element_size(pos_h)); ggml_tensor * pos_b = ggml_view_1d(ctx0, pos_w, n_pos_patches, pos_offset * (int64_t) ggml_element_size(pos_w)); - // first half (H) ggml_tensor * first = ggml_view_3d(ctx0, patches, half, nhead, n_pos_patches, patches->nb[2], patches->nb[1], 0); @@ -85,7 +81,6 @@ ggml_cgraph * clip_graph_jinaclip2::build() { half, n_pos_patches, nhead, first_rot->nb[2], first_rot->nb[1], 0); - // second half (W) ggml_tensor * second = ggml_view_3d(ctx0, patches, half, nhead, n_pos_patches, patches->nb[2], patches->nb[1], @@ -119,8 +114,7 @@ ggml_cgraph * clip_graph_jinaclip2::build() { nullptr, add_pos); - // Output: CLS embedding only (1 token). - ggml_tensor * cls = ggml_view_2d(ctx0, cur, cur->ne[0], /*rows=*/1, cur->nb[1], /*offset=*/0); + ggml_tensor * cls = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0); ggml_set_name(cls, "cls_view"); ggml_build_forward_expand(gf, cls); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index e0b0eb67e9..9159b5cf56 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -41,11 +41,7 @@ static void show_additional_info(int /*argc*/, char ** argv) { "Experimental CLI for multimodal\n\n" "Usage: %s [options] -m --mmproj --image --audio