This commit is contained in:
pockers21 2026-02-02 00:18:18 +02:00 committed by GitHub
commit 6d71ec0e16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 665 additions and 6 deletions

View File

@ -2746,14 +2746,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, int value) {
params.embd_normalize = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG}));
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_DEBUG}));
add_opt(common_arg(
{"--embd-output-format"}, "FORMAT",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
[](common_params & params, const std::string & value) {
params.embd_out = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
add_opt(common_arg(
{"--embd-separator"}, "STRING",
"separator of embeddings (default \\n) for example \"<#sep#>\"",

View File

@ -1808,7 +1808,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
n_block_keys = ["layers", "n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@ -1830,7 +1830,13 @@ class MmprojModel(ModelBase):
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
n_embd_text = (
text_config.get("hidden_size")
or text_config.get("n_embd")
or text_config.get("embed_dim")
or 0
)
self.n_embd_text = int(n_embd_text) if n_embd_text else 0
else:
text_config = {
k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"]
@ -7049,6 +7055,130 @@ class JinaBertV2Model(BertModel):
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
@ModelBase.register("JinaCLIPModel")
class JinaCLIPTextModel(XLMRobertaModel):
model_arch = gguf.MODEL_ARCH.BERT
_text_prefix = "text_model.transformer."
@staticmethod
def _load_json_file(path: Path) -> dict[str, Any]:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def _load_hf_config_json(hf_name_or_path: str) -> dict[str, Any]:
p = Path(hf_name_or_path)
if p.is_dir():
cfg_path = p / "config.json"
if cfg_path.is_file():
return JinaCLIPTextModel._load_json_file(cfg_path)
try:
from huggingface_hub import hf_hub_download
except Exception:
raise ImportError(
"huggingface_hub is required to fetch the text tower config.json for JinaClip; "
"install this package or provide a local path in text_config.hf_model_name_or_path."
)
try:
cfg_path = Path(hf_hub_download(repo_id=hf_name_or_path, filename="config.json", local_files_only=True))
except Exception:
cfg_path = Path(hf_hub_download(repo_id=hf_name_or_path, filename="config.json", local_files_only=False))
return JinaCLIPTextModel._load_json_file(cfg_path)
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
jinaclip_hparams = ModelBase.load_hparams(dir_model, False)
text_cfg = jinaclip_hparams.get("text_config") or {}
hf_name = text_cfg.get("hf_model_name_or_path")
if not hf_name:
raise KeyError("JinaCLIPTextModel: missing text_config.hf_model_name_or_path in config.json")
base_cfg = self._load_hf_config_json(str(hf_name))
overrides = text_cfg.get("hf_model_config_kwargs") or {}
if not isinstance(overrides, dict):
raise TypeError("JinaCLIPTextModel: text_config.hf_model_config_kwargs must be a dict")
merged_hparams = {**base_cfg, **overrides}
kwargs["hparams"] = merged_hparams
super().__init__(dir_model, ftype, fname_out, **kwargs)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self._text_prefix):
return []
name = name[len(self._text_prefix):]
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("JinaCLIPModel")
class JinaCLIPVisionModel(MmprojModel):
def set_gguf_parameters(self):
cfg = self.hparams
width = int(self.find_hparam(["width"]))
head_width = int(self.find_hparam(["head_width"]))
layers = int(self.find_hparam(["layers"]))
image_size = int(self.find_hparam(["image_size"]))
patch_size = int(self.find_hparam(["patch_size"]))
if width % head_width != 0:
raise ValueError(
f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
)
n_head = width // head_width
if "mlp_ratio" in cfg:
n_ff = int(width * float(cfg["mlp_ratio"]))
elif bool(cfg.get("naive_swiglu", False)):
n_ff = int((width * 8) // 3)
else:
raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_clip_has_vision_encoder(True)
proj_dim = int(self.global_config.get("projection_dim") or cfg.get("embed_dim") or width)
self.gguf_writer.add_vision_projection_dim(proj_dim)
self.gguf_writer.add_vision_image_size(image_size)
self.gguf_writer.add_vision_patch_size(patch_size)
self.gguf_writer.add_vision_embedding_length(width)
self.gguf_writer.add_vision_feed_forward_length(n_ff)
self.gguf_writer.add_vision_block_count(layers)
self.gguf_writer.add_vision_head_count(n_head)
self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-6)))
# JinaClip v2 uses mean/std in preprocessor_config.json
mean = self.preprocessor_config["mean"]
std = self.preprocessor_config["std"]
self.gguf_writer.add_vision_image_mean(mean)
self.gguf_writer.add_vision_image_std(std)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
self.gguf_writer.add_vision_use_silu(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("vision_model."):
name = name[len("vision_model."):]
elif not (name.startswith("v.") or name.startswith("mm.")):
return []
if name == "pos_embed":
pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + ".weight"
return [(pos_name, data_torch)]
try:
return [(self.map_tensor_name(name), data_torch)]
except Exception:
logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
return []
@ModelBase.register("OpenELMForCausalLM")
class OpenELMModel(TextModel):
model_arch = gguf.MODEL_ARCH.OPENELM

View File

@ -668,9 +668,13 @@ class MODEL_TENSOR(IntEnum):
V_ENC_ATTN_O = auto()
V_ENC_ATTN_O_NORM = auto()
V_ENC_POST_ATTN_NORM = auto()
V_ENC_ATTN_LN = auto()
V_ENC_FFN_UP = auto()
V_ENC_FFN_GATE = auto()
V_ENC_FFN_DOWN = auto()
V_ENC_FFN_NORM = auto()
V_ENC_ATTN_Q_BIAS = auto()
V_ENC_ATTN_V_BIAS = auto()
V_LAYER_SCALE_1 = auto()
V_LAYER_SCALE_2 = auto()
V_PRE_NORM = auto()
@ -1086,9 +1090,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
MODEL_TENSOR.V_ENC_ATTN_LN: "v.blk.{bid}.attn_ln",
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_ENC_FFN_NORM: "v.blk.{bid}.ffn_norm",
MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: "v.blk.{bid}.attn_q.bias",
MODEL_TENSOR.V_ENC_ATTN_V_BIAS: "v.blk.{bid}.attn_v.bias",
MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
@ -1204,9 +1212,13 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_ENC_ATTN_O,
MODEL_TENSOR.V_ENC_ATTN_O_NORM,
MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
MODEL_TENSOR.V_ENC_ATTN_LN,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_GATE,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_ENC_FFN_NORM,
MODEL_TENSOR.V_ENC_ATTN_Q_BIAS,
MODEL_TENSOR.V_ENC_ATTN_V_BIAS,
MODEL_TENSOR.V_LAYER_SCALE_1,
MODEL_TENSOR.V_LAYER_SCALE_2,
MODEL_TENSOR.V_PRE_NORM,
@ -3604,6 +3616,7 @@ class VisionProjectorType:
QWEN3VL = "qwen3vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
JINACLIP2 = "jinaclip2"
QWEN2A = "qwen2a" # audio
GLMA = "glma" # audio
QWEN25O = "qwen2.5o" # omni

View File

@ -1281,6 +1281,7 @@ class TensorNameMap:
"model.vision_tower.embeddings.cls_token", # Intern-S1
"vision_model.class_embedding", # llama 4
"model.vision.patch_embedding.cls_embedding", # cogvlm
"cls_token", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@ -1295,6 +1296,7 @@ class TensorNameMap:
"vision_tower.patch_embed.proj", # kimi-vl
"model.vision.patch_embedding.proj", # cogvlm
"siglip2.vision_model.embeddings.patch_embedding",
"patch_embed.proj", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_EMBD_NORM: (
@ -1329,6 +1331,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl
"blocks.{bid}.attn.q_proj", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@ -1347,6 +1350,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"blocks.{bid}.attn.k_proj", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@ -1365,6 +1369,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"blocks.{bid}.attn.v_proj", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
@ -1380,6 +1385,7 @@ class TensorNameMap:
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
"blocks.{bid}.norm1", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_ATTN_O: (
@ -1396,6 +1402,7 @@ class TensorNameMap:
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
"blocks.{bid}.attn.proj", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@ -1411,6 +1418,11 @@ class TensorNameMap:
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
"blocks.{bid}.norm2", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_ATTN_LN: (
"blocks.{bid}.attn.inner_attn_ln", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_FFN_UP: (
@ -1427,12 +1439,14 @@ class TensorNameMap:
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
"blocks.{bid}.mlp.w2", # JinaCLIP v2 vision (up)
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
"blocks.{bid}.mlp.w1", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_FFN_DOWN: (
@ -1449,6 +1463,11 @@ class TensorNameMap:
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
"blocks.{bid}.mlp.w3", # JinaCLIP v2 vision (down)
),
MODEL_TENSOR.V_ENC_FFN_NORM: (
"blocks.{bid}.mlp.ffn_ln", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_LAYER_SCALE_1: (
@ -1461,6 +1480,14 @@ class TensorNameMap:
"model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1
),
MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: (
"blocks.{bid}.attn.q_bias", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_ENC_ATTN_V_BIAS: (
"blocks.{bid}.attn.v_bias", # JinaCLIP v2 vision
),
MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
"vision_tower.ln_pre", # pixtral-hf
@ -1474,6 +1501,7 @@ class TensorNameMap:
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
"norm", # JinaCLIP v2 vision
"visual.post_layernorm", # glm4v
"siglip2.vision_model.post_layernorm",
),

View File

@ -19,6 +19,7 @@ add_library(mtmd
models/glm4v.cpp
models/internvl.cpp
models/kimivl.cpp
models/jinaclip2.cpp
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp

View File

@ -31,6 +31,7 @@ struct clip_graph {
const float eps;
const float kq_scale;
const clip_flash_attn_type flash_attn_type;
norm_type block_norm_t = NORM_TYPE_NORMAL;
ggml_context_ptr ctx0_ptr;
ggml_context * ctx0;

View File

@ -44,6 +44,7 @@
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
#define KEY_VISION_ROPE_THETA "clip.vision.rope_theta"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@ -75,6 +76,7 @@
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s"
#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s"
#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
@ -225,6 +227,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_MUSIC_FLAMINGO,
PROJECTOR_TYPE_JINACLIP2, // JinaCLIP v2
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_LIGHTONOCR,
@ -261,6 +264,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_JINACLIP2, "jinaclip2"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},

View File

@ -117,6 +117,9 @@ struct clip_layer {
ggml_tensor * k_norm = nullptr;
ggml_tensor * q_norm = nullptr;
ggml_tensor * attn_out_norm_w = nullptr;
ggml_tensor * attn_out_norm_b = nullptr;
// layernorm 1
ggml_tensor * ln_1_w = nullptr;
ggml_tensor * ln_1_b = nullptr;
@ -125,6 +128,8 @@ struct clip_layer {
ggml_tensor * ff_up_b = nullptr;
ggml_tensor * ff_gate_w = nullptr;
ggml_tensor * ff_gate_b = nullptr;
ggml_tensor * ffn_hidden_norm_w = nullptr;
ggml_tensor * ffn_hidden_norm_b = nullptr;
ggml_tensor * ff_down_w = nullptr;
ggml_tensor * ff_down_b = nullptr;

View File

@ -292,6 +292,8 @@ ggml_tensor * clip_graph::build_vit(
ggml_tensor * learned_pos_embd,
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
) {
block_norm_t = norm_t;
if (learned_pos_embd) {
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "pos_embed", -1);
@ -489,7 +491,6 @@ ggml_tensor * clip_graph::build_norm(
cur = ggml_add(ctx0, cur, mb);
cb(cur, "norm_b", il);
}
return cur;
}
@ -560,6 +561,14 @@ ggml_tensor * clip_graph::build_ffn(
} break;
}
if (il >= 0 && il < (int) model.layers.size()) {
const auto & layer = model.layers[il];
if (layer.ffn_hidden_norm_w) {
cur = build_norm(cur, layer.ffn_hidden_norm_w, layer.ffn_hidden_norm_b, block_norm_t, eps, il);
cb(cur, "ffn_hidden_normed", il);
}
}
if (down) {
cur = ggml_mul_mat(ctx0, down, cur);
}
@ -629,6 +638,14 @@ ggml_tensor * clip_graph::build_attn(
cb(cur, "kqv_out", il);
if (il >= 0 && il < (int) model.layers.size()) {
const auto & layer = model.layers[il];
if (layer.attn_out_norm_w) {
cur = build_norm(cur, layer.attn_out_norm_w, layer.attn_out_norm_b, block_norm_t, eps, il);
cb(cur, "kqv_out_normed", il);
}
}
if (wo) {
cur = ggml_mul_mat(ctx0, wo, cur);
}
@ -813,6 +830,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_llama4>(ctx, img);
} break;
case PROJECTOR_TYPE_JINACLIP2:
{
builder = std::make_unique<clip_graph_jinaclip2>(ctx, img);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_QWEN2A:
@ -1200,6 +1221,11 @@ struct clip_model_loader {
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
set_llava_uhd_res_candidates(model, 3);
} break;
case PROJECTOR_TYPE_JINACLIP2:
{
hparams.rope_theta = 10000.0f;
get_f32(KEY_VISION_ROPE_THETA, hparams.rope_theta, false);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA:
@ -1358,6 +1384,7 @@ struct clip_model_loader {
layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.attn_out_norm_w = get_tensor(string_format(TN_ATTN_LN, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
@ -1368,6 +1395,7 @@ struct clip_model_loader {
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
layer.attn_out_norm_b = get_tensor(string_format(TN_ATTN_LN, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
@ -1376,6 +1404,8 @@ struct clip_model_loader {
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
layer.ffn_hidden_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"), false);
layer.ffn_hidden_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
@ -1785,6 +1815,9 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
} break;
case PROJECTOR_TYPE_JINACLIP2:
{
} break;
case PROJECTOR_TYPE_LFM2A:
{
for (int i : {0, 2, 3, 5, 6}) {
@ -3020,6 +3053,41 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->grid_y = inst.grid_size.height;
} break;
case PROJECTOR_TYPE_JINACLIP2:
{
clip_image_u8 processed_image;
const int sz = params.image_size;
const int in_w = img->nx;
const int in_h = img->ny;
if (in_w <= 0 || in_h <= 0) {
LOG_ERR("%s: invalid input image size %dx%d\n", __func__, in_w, in_h);
return false;
}
int out_w = 0, out_h = 0;
if (in_w < in_h) {
out_w = sz;
out_h = std::max(1, (int) std::round((double) in_h * sz / in_w));
} else {
out_h = sz;
out_w = std::max(1, (int) std::round((double) in_w * sz / in_h));
}
clip_image_u8 resized_keep_ratio;
img_tool::resize(*img, resized_keep_ratio, clip_image_size{out_w, out_h}, img_tool::RESIZE_ALGO_BICUBIC);
const int x0 = std::max(0, (resized_keep_ratio.nx - sz) / 2);
const int y0 = std::max(0, (resized_keep_ratio.ny - sz) / 2);
const int crop_w = std::min(sz, resized_keep_ratio.nx);
const int crop_h = std::min(sz, resized_keep_ratio.ny);
img_tool::crop(resized_keep_ratio, processed_image, x0, y0, crop_w, crop_h);
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(processed_image, *img_f32, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(img_f32));
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
{
@ -3183,6 +3251,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
// do nothing
} break;
case PROJECTOR_TYPE_JINACLIP2:
{
n_patches = 1;
} break;
case PROJECTOR_TYPE_LDP:
case PROJECTOR_TYPE_LDPV2:
case PROJECTOR_TYPE_GLM_EDGE:
@ -3613,6 +3685,52 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_JINACLIP2:
{
std::vector<int32_t> positions(n_pos);
for (int i = 0; i < n_pos; i++) {
positions[i] = i;
}
set_input_i32("positions", positions);
const int n_patches = model.class_embedding ? (n_pos - 1) : n_pos;
const int n_patches_per_col = image_size_width / patch_size;
std::vector<int32_t> pos_data(n_pos, 0);
for (int i = 0; i < n_patches; ++i) {
const int idx = model.class_embedding ? (i + 1) : i;
pos_data[idx] = i / n_patches_per_col;
}
set_input_i32("pos_h", pos_data);
std::fill(pos_data.begin(), pos_data.end(), 0);
for (int i = 0; i < n_patches; ++i) {
const int idx = model.class_embedding ? (i + 1) : i;
pos_data[idx] = i % n_patches_per_col;
}
set_input_i32("pos_w", pos_data);
int pt_seq_len = 16;
if (patch_size > 0) {
const int cand = (int) llroundf(224.0f / (float) patch_size);
if (cand > 0) {
pt_seq_len = cand;
}
}
const float s = (float) pt_seq_len / (float) n_patches_per_col;
const int d_head_local = hparams.n_embd / hparams.n_head;
const int half_local = d_head_local / 2;
std::vector<float> rope_c_first(half_local);
std::vector<float> rope_c_second(half_local);
for (int k = 0; k < half_local; ++k) {
rope_c_first[k] = 1.0f / s;
rope_c_second[k] = 1.0f / s;
}
set_input_f32("rope_c_first", rope_c_first);
set_input_f32("rope_c_second", rope_c_second);
} break;
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
case PROJECTOR_TYPE_LDP:
@ -3737,6 +3855,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_JINACLIP2:
return ctx->model.hparams.projection_dim;
case PROJECTOR_TYPE_MLP_NORM:
return ctx->model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV:

View File

@ -0,0 +1,122 @@
#include "models.h"
#include <cmath>
ggml_cgraph * clip_graph_jinaclip2::build() {
const bool has_cls = model.class_embedding != nullptr;
GGML_ASSERT(has_cls && "JinaCLIP2 requires a CLS token");
const int n_pos = n_patches + (has_cls ? 1 : 0);
GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");
ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
GGML_ASSERT(d_head % 2 == 0);
ggml_tensor * rope_c_first = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head / 2);
ggml_set_name(rope_c_first, "rope_c_first");
ggml_set_input(rope_c_first);
ggml_tensor * rope_c_second = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head / 2);
ggml_set_name(rope_c_second, "rope_c_second");
ggml_set_input(rope_c_second);
ggml_tensor * inp = build_inp();
if (has_cls) {
inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
}
inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));
auto apply_rope_2d = [&](ggml_tensor * cur) -> ggml_tensor * {
ggml_tensor * cur_in = ggml_permute(ctx0, cur, 0, 2, 1, 3);
const int64_t n_dim = cur_in->ne[0];
const int64_t seq = cur_in->ne[1];
const int64_t nhead = cur_in->ne[2];
GGML_ASSERT(seq == n_pos);
GGML_ASSERT(n_dim % 2 == 0);
const int64_t half = n_dim / 2;
ggml_tensor * cls = nullptr;
ggml_tensor * patches = cur_in;
int64_t n_pos_patches = seq;
int64_t pos_offset = 0;
if (has_cls) {
cls = ggml_view_3d(ctx0, cur_in, n_dim, 1, nhead, cur_in->nb[1], cur_in->nb[2], 0);
patches = ggml_view_3d(ctx0, cur_in, n_dim, seq - 1, nhead, cur_in->nb[1], cur_in->nb[2], cur_in->nb[1]);
n_pos_patches = seq - 1;
pos_offset = 1;
}
// select positions
ggml_tensor * pos_a = ggml_view_1d(ctx0, pos_h, n_pos_patches, pos_offset * (int64_t) ggml_element_size(pos_h));
ggml_tensor * pos_b = ggml_view_1d(ctx0, pos_w, n_pos_patches, pos_offset * (int64_t) ggml_element_size(pos_w));
ggml_tensor * first = ggml_view_3d(ctx0, patches,
half, nhead, n_pos_patches,
patches->nb[2], patches->nb[1], 0);
ggml_tensor * first_rot = ggml_rope_ext(
ctx0,
first,
pos_a,
rope_c_first,
half,
0, 0, hparams.rope_theta,
1.0f,
0.0f, 1.0f, 0.0f, 0.0f);
first = ggml_view_3d(ctx0, first_rot,
half, n_pos_patches, nhead,
first_rot->nb[2], first_rot->nb[1], 0);
ggml_tensor * second = ggml_view_3d(ctx0, patches,
half, nhead, n_pos_patches,
patches->nb[2], patches->nb[1],
half * (int64_t) ggml_element_size(patches));
ggml_tensor * second_rot = ggml_rope_ext(
ctx0,
second,
pos_b,
rope_c_second,
half,
0, 0, hparams.rope_theta,
1.0f,
0.0f, 1.0f, 0.0f, 0.0f);
second = ggml_view_3d(ctx0, second_rot,
half, n_pos_patches, nhead,
second_rot->nb[2], second_rot->nb[1], 0);
ggml_tensor * patches_out = ggml_concat(ctx0, first, second, 0);
ggml_tensor * out_seq = has_cls ? ggml_concat(ctx0, cls, patches_out, 1) : patches_out;
return ggml_permute(ctx0, out_seq, 0, 2, 1, 3);
};
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
return apply_rope_2d(cur);
};
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
nullptr,
add_pos);
ggml_tensor * cls = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0);
ggml_set_name(cls, "cls_view");
ggml_build_forward_expand(gf, cls);
return gf;
}

View File

@ -52,6 +52,11 @@ struct clip_graph_kimivl : clip_graph {
ggml_cgraph * build() override;
};
struct clip_graph_jinaclip2 : clip_graph {
clip_graph_jinaclip2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_cogvlm : clip_graph {
clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

View File

@ -40,7 +40,8 @@ static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Experimental CLI for multimodal\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> --audio <audio> -p <prompt>\n\n"
" -m and --mmproj are required\n"
" -m and --mmproj are required in chat/generation modes\n"
" Embedding mode: -m + --mmproj + --image + --embd-output-format (no prompt required)\n"
" -hf user/repo can replace both -m and --mmproj in most cases\n"
" --image, --audio and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" to disable using GPU for mmproj model, add --no-mmproj-offload\n",
@ -174,6 +175,117 @@ struct mtmd_cli_context {
}
};
static int run_mmproj_only(common_params & params) {
if (params.embd_out.empty()) return -1;
if (!params.prompt.empty()) return -1;
if (params.mmproj.path.empty() || params.image.empty()) return -1;
mtmd_context_params ctx_params = mtmd_context_params_default();
ctx_params.use_gpu = params.mmproj_use_gpu;
ctx_params.warmup = params.warmup;
ctx_params.image_min_tokens = params.image_min_tokens;
ctx_params.image_max_tokens = params.image_max_tokens;
mtmd_mmproj_context_t mctx = mtmd_mmproj_init(params.mmproj.path.c_str(), ctx_params);
if (!mctx) {
LOG_ERR("[ERROR] Failed to load vision mmproj: %s\n", params.mmproj.path.c_str());
return 1;
}
const std::string fmt = params.embd_out;
std::vector<std::vector<float>> embeddings;
embeddings.reserve(params.image.size());
for (size_t i = 0; i < params.image.size(); ++i) {
const char * image_path = params.image[i].c_str();
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file_noctx(image_path));
if (!bmp.ptr) {
LOG_ERR("[ERROR] Failed to decode image %s\n", image_path);
mtmd_mmproj_free(mctx);
return 1;
}
float * emb = nullptr; size_t n_el = 0;
int enc_rc = mtmd_mmproj_encode_bitmap(mctx, bmp.ptr.get(), params.cpuparams.n_threads, &emb, &n_el);
if (enc_rc != 0 || !emb || n_el == 0) {
LOG_ERR("[ERROR] Image encoding failed: %s\n", image_path);
mtmd_mmproj_free(mctx);
return 1;
}
std::vector<float> image_embd(emb, emb + n_el);
std::free(emb);
if (params.embd_normalize != -1) {
common_embd_normalize(image_embd.data(), image_embd.data(), (int) image_embd.size(), params.embd_normalize);
}
embeddings.emplace_back(std::move(image_embd));
}
const bool is_array = fmt == "array";
const bool is_json = fmt == "json" || fmt == "json+";
if (is_array || is_json) {
const bool not_array = !is_array;
LOG(not_array ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
for (size_t j = 0; j < embeddings.size(); ++j) {
const auto & e = embeddings[j];
if (not_array) LOG(" {\n \"object\": \"embedding\",\n \"index\": %zu,\n \"embedding\": ", j);
LOG("[");
for (size_t i = 0; i < e.size(); ++i) {
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", e[i]);
if (i + 1 < e.size()) LOG(",");
}
LOG(not_array ? "]\n }" : "]");
if (j + 1 < embeddings.size()) LOG(not_array ? ",\n" : ",");
}
LOG(not_array ? "\n ]" : "]\n");
if (fmt == "json+" && embeddings.size() > 1) {
bool same_dim = true;
const size_t n_dim = embeddings[0].size();
for (size_t i = 1; i < embeddings.size(); ++i) {
if (embeddings[i].size() != n_dim) {
same_dim = false;
break;
}
}
if (same_dim) {
LOG(",\n \"cosineSimilarity\": [\n");
for (size_t i = 0; i < embeddings.size(); ++i) {
LOG(" [");
for (size_t j = 0; j < embeddings.size(); ++j) {
float sim = common_embd_similarity_cos(embeddings[i].data(), embeddings[j].data(), (int) n_dim);
LOG("%6.2f", sim);
if (j + 1 < embeddings.size()) LOG(", ");
}
LOG(" ]");
if (i + 1 < embeddings.size()) LOG(",\n");
}
LOG("\n ]");
}
}
if (not_array) LOG("\n}\n");
} else if (fmt == "raw") {
for (size_t j = 0; j < embeddings.size(); ++j) {
const auto & e = embeddings[j];
for (size_t i = 0; i < e.size(); ++i) {
if (i) LOG(" ");
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", e[i]);
}
LOG("\n");
}
} else {
LOG_ERR("[ERROR] Invalid --embd-output-format: '%s'\n", fmt.c_str());
mtmd_mmproj_free(mctx);
return 1;
}
mtmd_mmproj_free(mctx);
return 0;
}
static int generate_response(mtmd_cli_context & ctx, int n_predict) {
llama_tokens generated_tokens;
for (int i = 0; i < n_predict; i++) {
@ -282,6 +394,11 @@ int main(int argc, char ** argv) {
return 1;
}
{
int rc = run_mmproj_only(params);
if (rc >= 0) return rc;
}
common_init();
mtmd_helper_log_set(common_log_default_callback, nullptr);

View File

@ -519,3 +519,15 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
return mtmd_helper_bitmap_init_from_buf(ctx, buf.data(), buf.size());
}
mtmd_bitmap * mtmd_helper_bitmap_init_from_file_noctx(const char * fname) {
int nx = 0, ny = 0, nc = 0;
unsigned char * data = stbi_load(fname, &nx, &ny, &nc, 3);
if (!data) {
LOG_ERR("%s: failed to decode image file %s\n", __func__, fname);
return nullptr;
}
mtmd_bitmap * result = mtmd_bitmap_init(nx, ny, data);
stbi_image_free(data);
return result;
}

View File

@ -40,6 +40,9 @@ MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, con
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len);
// Decode an image file without mtmd_context
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file_noctx(const char * fname);
// helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);

View File

@ -425,6 +425,87 @@ void mtmd_free(mtmd_context * ctx) {
delete ctx;
}
struct mtmd_mmproj_context {
clip_ctx * ctx_v = nullptr;
};
mtmd_mmproj_context_t mtmd_mmproj_init(const char * mmproj_fname,
const struct mtmd_context_params ctx_params) {
clip_context_params clip_params{};
clip_params.use_gpu = ctx_params.use_gpu;
clip_params.flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
clip_params.image_min_tokens = ctx_params.image_min_tokens;
clip_params.image_max_tokens = ctx_params.image_max_tokens;
clip_params.warmup = ctx_params.warmup;
clip_params.cb_eval = nullptr;
clip_params.cb_eval_user_data = nullptr;
auto res = clip_init(mmproj_fname, clip_params);
if (!res.ctx_v) {
return nullptr;
}
auto * ctx = new mtmd_mmproj_context();
ctx->ctx_v = res.ctx_v;
return ctx;
}
void mtmd_mmproj_free(mtmd_mmproj_context_t ctx) {
if (!ctx) return;
clip_free(ctx->ctx_v);
delete ctx;
}
int32_t mtmd_mmproj_encode_bitmap(mtmd_mmproj_context_t ctx,
const mtmd_bitmap * bmp,
int32_t n_threads,
float ** out_data,
size_t * out_count) {
if (!ctx || !ctx->ctx_v || !bmp || !out_data || !out_count) {
LOG_ERR("%s: invalid args: ctx=%p ctx_v=%p bmp=%p out_data=%p out_count=%p\n",
__func__, (void*) ctx, ctx ? (void*) ctx->ctx_v : (void*) nullptr,
(void*) bmp, (void*) out_data, (void*) out_count);
return 1;
}
clip_image_u8_ptr img_u8(clip_image_u8_init());
img_u8->nx = bmp->nx;
img_u8->ny = bmp->ny;
img_u8->buf.resize(bmp->data.size());
std::memcpy(img_u8->buf.data(), bmp->data.data(), img_u8->nx * img_u8->ny * 3);
clip_image_f32_batch batch_f32;
bool ok = clip_image_preprocess(ctx->ctx_v, img_u8.get(), &batch_f32);
if (!ok) {
LOG_ERR("%s: image preprocess failed (nx=%u ny=%u proj=%d)\n",
__func__, img_u8->nx, img_u8->ny, (int) clip_get_projector_type(ctx->ctx_v));
return 1;
}
clip_image_f32 * processed_img = clip_image_f32_get_img(&batch_f32, 0);
if (!processed_img) {
LOG_ERR("%s: preprocessed image is null\n", __func__);
return 1;
}
const int n_tok = clip_n_output_tokens(ctx->ctx_v, processed_img);
const int n_embd = clip_n_mmproj_embd(ctx->ctx_v);
const size_t n_el = (size_t) n_tok * (size_t) n_embd;
std::vector<float> buf(n_el);
if (!clip_image_encode(ctx->ctx_v, n_threads, processed_img, buf.data())) {
LOG_ERR("%s: image encode failed (threads=%d tokens=%d embd=%d)\n",
__func__, n_threads, n_tok, n_embd);
return 1;
}
float * out = (float *) std::malloc(n_el * sizeof(float));
if (!out) {
LOG_ERR("%s: malloc failed (elements=%zu bytes=%zu)\n", __func__, n_el, n_el * sizeof(float));
return 1;
}
std::memcpy(out, buf.data(), n_el * sizeof(float));
*out_data = out;
*out_count = n_el;
return 0;
}
struct mtmd_tokenizer {
mtmd_context * ctx;
std::vector<const mtmd_bitmap *> bitmaps;

View File

@ -231,6 +231,23 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
// If this is not called, or NULL is supplied, everything is output on stderr.
MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
typedef struct mtmd_mmproj_context * mtmd_mmproj_context_t;
// initialize a minimal context that only loads the projector
MTMD_API mtmd_mmproj_context_t mtmd_mmproj_init(const char * mmproj_fname,
const struct mtmd_context_params ctx_params);
// free projector-only context
MTMD_API void mtmd_mmproj_free(mtmd_mmproj_context_t ctx);
// encode a bitmap to projector embeddings
// returns 0 on success, 1 on failure
MTMD_API int32_t mtmd_mmproj_encode_bitmap(mtmd_mmproj_context_t ctx,
const mtmd_bitmap * bmp,
int32_t n_threads,
float ** out_data,
size_t * out_count);
/////////////////////////////////////////
// test function, to be used in test-mtmd-c-api.c