Saba Fallah 2025-12-17 12:00:13 +08:00 committed by GitHub
commit 6a8a8dbac6
30 changed files with 1843 additions and 40 deletions

View File

@ -711,6 +711,9 @@ class ModelBase:
if "thinker_config" in config: if "thinker_config" in config:
# rename for Qwen2.5-Omni # rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"] config["text_config"] = config["thinker_config"]["text_config"]
if "language_config" in config:
# rename for DeepSeekOCR
config["text_config"] = config["language_config"]
return config
@classmethod
@ -1688,7 +1691,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers", "encoder_layers"]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@ -5956,6 +5959,68 @@ class Gemma3VisionModel(MmprojModel):
return [] # skip other tensors
@ModelBase.register("DeepseekOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR)
# default values below are taken from HF transformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_vision_use_gelu(True)
# calculate proj_scale_factor (used by tinygemma3 test model)
image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
n_per_side = int(image_seq_length ** 0.5)
image_size = self.hparams["image_size"]
patch_size = self.hparams["patch_size"]
proj_scale_factor = (image_size // patch_size) // n_per_side
if proj_scale_factor > 0 and proj_scale_factor != 4:
# we only need to write this if it's not the default value
# in this case, we are converting a test model
self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
# @bluebread: there's no window_size in config but just add it here anyway
self.gguf_writer.add_vision_window_size(self.hparams.get("window_size", 14))
# SAM configuration
sam_hparams = hparams['sam']
self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers'])
self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width'])
self.gguf_writer.add_vision_sam_head_count(sam_hparams['heads'])
def get_vision_config(self) -> dict[str, Any]:
vision_config: dict[str, Any] | None = self.global_config.get("vision_config")
if not vision_config:
raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found")
vision_config['sam'] = vision_config['width']['sam_vit_b']
vision_config.update(vision_config['width']['clip-l-14-224'])
vision_config['hidden_size'] = vision_config['width']
vision_config['num_heads'] = vision_config['heads']
vision_config['intermediate_size'] = vision_config['heads'] * 4
return vision_config
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".embeddings." in name or 'pos_embed' in name:
return gguf.GGMLQuantizationType.F32
if ".rel_pos_h" in name or '.rel_pos_w' in name:
return gguf.GGMLQuantizationType.F32
return gguf.GGMLQuantizationType.F16
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision-related tensors, skip language model tensors
# Vision components: sam_model, vision_model, projector, image_newline, view_seperator
# Language model components to skip: lm_head, embed_tokens, layers, norm
if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")):
return []
if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name:
return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]
return [(self.map_tensor_name(name), data_torch)]
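For reference, get_vision_config() above flattens a nested DeepSeek-OCR config. A minimal sketch of the layout it assumes, with the nesting inferred from the dictionary accesses in the method and all numeric values used as illustrative placeholders rather than values from the real checkpoint:

# Sketch of the "vision_config" shape expected by get_vision_config();
# the nesting follows the key accesses above, the numbers are placeholders.
vision_config = {
    "image_size": 1024,      # placeholder
    "patch_size": 16,        # placeholder
    "layer_norm_eps": 1e-6,  # placeholder
    "width": {
        "sam_vit_b":     {"layers": 12, "width": 768,  "heads": 12},  # SAM branch (illustrative)
        "clip-l-14-224": {"layers": 24, "width": 1024, "heads": 16},  # CLIP branch (illustrative)
    },
}

After vision_config.update(vision_config['width']['clip-l-14-224']), the top-level width/heads/layers keys describe the CLIP encoder (presumably why "layers" is added to n_block_keys above), while the SAM sub-config survives under the 'sam' key and is written out through the add_vision_sam_* keys in set_gguf_parameters().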
@ModelBase.register("Gemma3nForConditionalGeneration") @ModelBase.register("Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model): class Gemma3NModel(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA3N model_arch = gguf.MODEL_ARCH.GEMMA3N
@ -7122,6 +7187,16 @@ class DeepseekModel(TextModel):
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams: dict = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
if self.origin_hf_arch == "DeepseekOCRForCausalLM":
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
def set_vocab(self):
try:
self._set_vocab_gpt2()
@ -7177,30 +7252,41 @@ class DeepseekV2Model(TextModel):
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
def set_gguf_parameters(self):
is_ocr = (self.model_arch == gguf.MODEL_ARCH.DEEPSEEK2OCR)
if is_ocr:
self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
else:
# note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1
self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6)
super().set_gguf_parameters()
hparams = self.hparams
kv_lora_rank = hparams["kv_lora_rank"] if hparams.get("kv_lora_rank") is not None else 512
routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
norm_topk_prob = hparams.get("norm_topk_prob", False)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None:
self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
# note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
if not is_ocr:
self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(kv_lora_rank)
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
@ -7214,7 +7300,12 @@ class DeepseekV2Model(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if ("vision_" in name
or "multi_modal_projector" in name
or "image_newline" in name
or "model.projector" in name
or "sam_model" in name
or "view_seperator" in name):
return []
if name.startswith("language_model."):

View File

@ -289,5 +289,7 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
} else {
GGML_ABORT("fatal error");
}
}

View File

@ -4914,6 +4914,7 @@ static struct ggml_tensor * ggml_interpolate_impl(
GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
// TODO: implement antialias for modes other than bilinear
GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
GGML_ASSERT(a->type == GGML_TYPE_F32);
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
@ -5259,6 +5260,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
GGML_ASSERT(q->ne[3] == v->ne[3]);
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(mask));
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));

View File

@ -290,6 +290,7 @@ class Keys:
IMAGE_MEAN = "clip.vision.image_mean"
IMAGE_STD = "clip.vision.image_std"
SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
WINDOW_SIZE = "clip.vision.window_size"
USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu"
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
@ -302,6 +303,11 @@ class Keys:
class Projector:
SCALE_FACTOR = "clip.vision.projector.scale_factor"
class SAM:
BLOCK_COUNT = "clip.vision.sam.block_count"
EMBEDDING_LENGTH = "clip.vision.sam.embedding_length"
HEAD_COUNT = "clip.vision.sam.head_count"
class ClipAudio:
NUM_MEL_BINS = "clip.audio.num_mel_bins"
EMBEDDING_LENGTH = "clip.audio.embedding_length"
@ -404,6 +410,7 @@ class MODEL_ARCH(IntEnum):
ARCTIC = auto()
DEEPSEEK = auto()
DEEPSEEK2 = auto()
DEEPSEEK2OCR = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
@ -688,6 +695,22 @@ class MODEL_TENSOR(IntEnum):
V_MM_GATE = auto() # cogvlm
V_TOK_BOI = auto() # cogvlm
V_TOK_EOI = auto() # cogvlm
V_SAM_POS_EMBD = auto() # Deepseek-OCR
V_SAM_PATCH_EMBD = auto() # Deepseek-OCR
V_SAM_PRE_NORM = auto() # Deepseek-OCR
V_SAM_POST_NORM = auto() # Deepseek-OCR
V_SAM_ATTN_POS_H = auto() # Deepseek-OCR
V_SAM_ATTN_POS_W = auto() # Deepseek-OCR
V_SAM_ATTN_QKV = auto() # Deepseek-OCR
V_SAM_ATTN_OUT = auto() # Deepseek-OCR
V_SAM_MLP_LIN_1 = auto() # Deepseek-OCR
V_SAM_MLP_LIN_2 = auto() # Deepseek-OCR
V_SAM_NECK = auto() # Deepseek-OCR
V_SAM_NET_2 = auto() # Deepseek-OCR
V_SAM_NET_3 = auto() # Deepseek-OCR
V_ENC_EMBD_IMGNL = auto() # Deepseek-OCR
V_ENC_EMBD_VSEP = auto() # Deepseek-OCR
# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_CONV1D = auto()
@ -780,6 +803,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
@ -1063,6 +1087,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_MM_GATE: "mm.gate",
MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# DeepSeek-OCR SAM
MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd",
MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd",
MODEL_TENSOR.V_SAM_PRE_NORM: "v.sam.blk.{bid}.pre_ln",
MODEL_TENSOR.V_SAM_POST_NORM: "v.sam.blk.{bid}.post_ln",
MODEL_TENSOR.V_SAM_ATTN_POS_H: "v.sam.blk.{bid}.attn.pos_h",
MODEL_TENSOR.V_SAM_ATTN_POS_W: "v.sam.blk.{bid}.attn.pos_w",
MODEL_TENSOR.V_SAM_ATTN_QKV: "v.sam.blk.{bid}.attn.qkv",
MODEL_TENSOR.V_SAM_ATTN_OUT: "v.sam.blk.{bid}.attn.out",
MODEL_TENSOR.V_SAM_MLP_LIN_1: "v.sam.blk.{bid}.mlp.lin1",
MODEL_TENSOR.V_SAM_MLP_LIN_2: "v.sam.blk.{bid}.mlp.lin2",
MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}",
MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2",
MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3",
MODEL_TENSOR.V_ENC_EMBD_IMGNL: "v.image_newline", # Deepseek-OCR
MODEL_TENSOR.V_ENC_EMBD_VSEP: "v.view_seperator", # Deepseek-OCR
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
@ -1100,6 +1140,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_NORM,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_EMBD_IMGNL,
MODEL_TENSOR.V_ENC_EMBD_VSEP,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_ATTN_QKV,
MODEL_TENSOR.V_ENC_ATTN_Q,
@ -1143,6 +1185,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_MM_GATE,
MODEL_TENSOR.V_TOK_BOI,
MODEL_TENSOR.V_TOK_EOI,
MODEL_TENSOR.V_SAM_POS_EMBD,
MODEL_TENSOR.V_SAM_PATCH_EMBD,
MODEL_TENSOR.V_SAM_PRE_NORM,
MODEL_TENSOR.V_SAM_POST_NORM,
MODEL_TENSOR.V_SAM_ATTN_POS_H,
MODEL_TENSOR.V_SAM_ATTN_POS_W,
MODEL_TENSOR.V_SAM_ATTN_QKV,
MODEL_TENSOR.V_SAM_ATTN_OUT,
MODEL_TENSOR.V_SAM_MLP_LIN_1,
MODEL_TENSOR.V_SAM_MLP_LIN_2,
MODEL_TENSOR.V_SAM_NECK,
MODEL_TENSOR.V_SAM_NET_2,
MODEL_TENSOR.V_SAM_NET_3,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_CONV1D,
@ -2311,7 +2366,41 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.DEEPSEEK2OCR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
@ -3174,6 +3263,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.DEEPSEEK2OCR: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
@ -3363,6 +3456,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
DEEPSEEKOCR = "deepseekocr"
GLM4V = "glm4v"

View File

@ -1112,6 +1112,9 @@ class GGUFWriter:
def add_vision_spatial_merge_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
def add_vision_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
def add_vision_use_gelu(self, value: bool) -> None:
self.add_bool(Keys.ClipVision.USE_GELU, value)
@ -1127,6 +1130,15 @@ class GGUFWriter:
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
def add_vision_sam_layers_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value)
def add_vision_sam_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value)
def add_vision_sam_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SAM.HEAD_COUNT, value)
# audio models
def add_audio_projection_dim(self, value: int) -> None:
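The three add_vision_sam_* helpers above are the writer-side counterparts of the clip.vision.sam.* keys read by the mtmd CLIP loader later in this commit. A quick, hypothetical way to confirm they landed in a converted mmproj file (the file path is a placeholder, not from the commit):

from gguf import GGUFReader

reader = GGUFReader("mmproj-deepseek-ocr.gguf")  # placeholder path
for key in ("clip.vision.sam.block_count",
            "clip.vision.sam.embedding_length",
            "clip.vision.sam.head_count",
            "clip.vision.window_size"):
    print(key, "->", "present" if key in reader.fields else "missing")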

View File

@ -1212,6 +1212,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ_FC: (
"model.connector.modality_projection.proj", # SmolVLM
"model.vision.linear_proj.linear_proj", # cogvlm
"model.projector.layers", # Deepseek-OCR
"visual.merger.proj", # glm4v
),
@ -1231,6 +1232,7 @@ class TensorNameMap:
"model.vision_tower.embeddings.cls_token", # Intern-S1 "model.vision_tower.embeddings.cls_token", # Intern-S1
"vision_model.class_embedding", # llama 4 "vision_model.class_embedding", # llama 4
"model.vision.patch_embedding.cls_embedding", # cogvlm "model.vision.patch_embedding.cls_embedding", # cogvlm
"model.vision_model.embeddings.class_embedding", # Deepseek-OCR
), ),
MODEL_TENSOR.V_ENC_EMBD_PATCH: ( MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@ -1244,6 +1246,7 @@ class TensorNameMap:
"visual.patch_embed.proj", # qwen2vl "visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl "vision_tower.patch_embed.proj", # kimi-vl
"model.vision.patch_embedding.proj", # cogvlm "model.vision.patch_embedding.proj", # cogvlm
"model.vision_model.embeddings.patch_embedding", # Deepseek-OCR CLIP
), ),
MODEL_TENSOR.V_ENC_EMBD_NORM: ( MODEL_TENSOR.V_ENC_EMBD_NORM: (
@ -1258,13 +1261,22 @@ class TensorNameMap:
"vision_model.positional_embedding_vlm", # llama 4 "vision_model.positional_embedding_vlm", # llama 4
"vision_tower.patch_embed.pos_emb", # kimi-vl "vision_tower.patch_embed.pos_emb", # kimi-vl
"visual.pos_embed", # qwen3vl "visual.pos_embed", # qwen3vl
"model.vision.patch_embedding.position_embedding", # cogvlm "model.vision.patch_embedding.position_embedding", # cogvlm
"visual.embeddings.position_embedding", # glm4v "visual.embeddings.position_embedding", # glm4v
), ),
MODEL_TENSOR.V_ENC_EMBD_IMGNL: (
"model.image_newline", # Deepseek-OCR
),
MODEL_TENSOR.V_ENC_EMBD_VSEP: (
"model.view_seperator", # Deepseek-OCR
),
MODEL_TENSOR.V_ENC_ATTN_QKV: (
"visual.blocks.{bid}.attn.qkv", # qwen3vl
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
"model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj", # Deepseek-OCR CLIP
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
@ -1277,6 +1289,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated "visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
"model.vision_model.transformer.layers.{bid}.self_attn.q_proj", # Deepseek-OCR CLIP, generated
), ),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@ -1294,6 +1307,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated "visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
"model.vision_model.transformer.layers.{bid}.self_attn.k_proj", # Deepseek-OCR CLIP, generated
), ),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@ -1311,6 +1325,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated "visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
"model.vision_model.transformer.layers.{bid}.self_attn.v_proj", # Deepseek-OCR CLIP, generated
), ),
MODEL_TENSOR.V_ENC_INPUT_NORM: ( MODEL_TENSOR.V_ENC_INPUT_NORM: (
@ -1325,6 +1340,7 @@ class TensorNameMap:
"visual.blocks.{bid}.norm1", # qwen2vl "visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
"model.vision_model.transformer.layers.{bid}.layer_norm1", # Deepseek-OCR CLIP
), ),
MODEL_TENSOR.V_ENC_ATTN_O: ( MODEL_TENSOR.V_ENC_ATTN_O: (
@ -1340,6 +1356,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.proj", # qwen2vl "visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
"model.vision_model.transformer.layers.{bid}.self_attn.out_proj", # Deepseek-OCR CLIP
), ),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@ -1354,6 +1371,7 @@ class TensorNameMap:
"visual.blocks.{bid}.norm2", # qwen2vl "visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
"model.vision_model.transformer.layers.{bid}.layer_norm2", # Deepseek-OCR CLIP
), ),
MODEL_TENSOR.V_ENC_FFN_UP: ( MODEL_TENSOR.V_ENC_FFN_UP: (
@ -1368,6 +1386,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision_model.transformer.layers.{bid}.mlp.fc1", # Deepseek-OCR CLIP
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
), ),
@ -1390,6 +1409,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
"model.vision_model.transformer.layers.{bid}.mlp.fc2", # Deepseek-OCR CLIP
), ),
MODEL_TENSOR.V_LAYER_SCALE_1: ( MODEL_TENSOR.V_LAYER_SCALE_1: (
@ -1407,6 +1427,7 @@ class TensorNameMap:
"vision_tower.ln_pre", # pixtral-hf "vision_tower.ln_pre", # pixtral-hf
"vision_encoder.ln_pre", # pixtral "vision_encoder.ln_pre", # pixtral
"vision_model.layernorm_pre", # llama4 "vision_model.layernorm_pre", # llama4
"model.vision_model.pre_layrnorm", # Deepseek-OCR CLIP
), ),
MODEL_TENSOR.V_POST_NORM: ( MODEL_TENSOR.V_POST_NORM: (
@ -1504,6 +1525,58 @@ class TensorNameMap:
"model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
), ),
MODEL_TENSOR.V_SAM_POS_EMBD: (
"model.sam_model.pos_embed",
),
MODEL_TENSOR.V_SAM_PATCH_EMBD: (
"model.sam_model.patch_embed.proj",
),
MODEL_TENSOR.V_SAM_PRE_NORM: (
"model.sam_model.blocks.{bid}.norm1", # deepstack in qwen3vl
),
MODEL_TENSOR.V_SAM_POST_NORM: (
"model.sam_model.blocks.{bid}.norm2", # deepstack in qwen3vl
),
MODEL_TENSOR.V_SAM_ATTN_POS_H: (
"model.sam_model.blocks.{bid}.attn.rel_pos_h",
),
MODEL_TENSOR.V_SAM_ATTN_POS_W: (
"model.sam_model.blocks.{bid}.attn.rel_pos_w",
),
MODEL_TENSOR.V_SAM_ATTN_QKV: (
"model.sam_model.blocks.{bid}.attn.qkv",
),
MODEL_TENSOR.V_SAM_ATTN_OUT: (
"model.sam_model.blocks.{bid}.attn.proj",
),
MODEL_TENSOR.V_SAM_MLP_LIN_1: (
"model.sam_model.blocks.{bid}.mlp.lin1",
),
MODEL_TENSOR.V_SAM_MLP_LIN_2: (
"model.sam_model.blocks.{bid}.mlp.lin2",
),
MODEL_TENSOR.V_SAM_NECK: (
"model.sam_model.neck.{bid}",
),
MODEL_TENSOR.V_SAM_NET_2: (
"model.sam_model.net_2",
),
MODEL_TENSOR.V_SAM_NET_3: (
"model.sam_model.net_3",
),
MODEL_TENSOR.V_MM_POST_FC_NORM: (
"model.vision.linear_proj.norm1", # cogvlm
),
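Read together with DeepseekOCRVisionModel.modify_tensors() in the converter, the new entries imply roughly the following checkpoint-to-GGUF renames (a hand-worked sketch based on the tables above, not output captured from an actual conversion):

# Illustrative HF -> GGUF tensor name pairs implied by the SAM / Deepseek-OCR
# mappings added above (suffix handling follows map_tensor_name()).
expected = {
    "model.sam_model.patch_embed.proj.weight":  "v.sam.patch_embd.weight",
    "model.sam_model.blocks.0.attn.qkv.weight": "v.sam.blk.0.attn.qkv.weight",
    "model.sam_model.blocks.0.attn.rel_pos_h":  "v.sam.blk.0.attn.pos_h",  # no ".weight" suffix
    "model.sam_model.neck.1.weight":            "v.sam.neck.1.weight",
    "model.image_newline":                      "v.image_newline",
    "model.view_seperator":                     "v.view_seperator",
}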

View File

@ -67,6 +67,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
@ -1462,6 +1463,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_FFN_EXP_PROBS_B,
};
case LLM_ARCH_DEEPSEEK2OCR:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_FFN_EXP_PROBS_B,
};
case LLM_ARCH_PLM:
return {
LLM_TENSOR_TOKEN_EMBD,

View File

@ -71,6 +71,7 @@ enum llm_arch {
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_DEEPSEEK2OCR,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,

View File

@ -49,6 +49,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
{ "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
{ "deepseek-ocr", LLM_CHAT_TEMPLATE_DEEPSEEK_OCR },
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
@ -541,6 +542,11 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << LU8("<Assistant>");
}
} else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_OCR) {
for (auto message : chat) {
// no template
ss << message->content;
}
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
// EXAONE-3.0-7.8B-Instruct

View File

@ -28,6 +28,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DEEPSEEK,
LLM_CHAT_TEMPLATE_DEEPSEEK_2,
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
LLM_CHAT_TEMPLATE_DEEPSEEK_OCR,
LLM_CHAT_TEMPLATE_COMMAND_R,
LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGLM_3,

View File

@ -1168,7 +1168,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
if (!weight_before_ffn) {
experts = ggml_mul(ctx0, experts, weights);
cb(experts, "ffn_moe_weighted", il);
}
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };

View File

@ -1400,6 +1400,14 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
? LLAMA_ROPE_TYPE_NEOX
: hparams.rope_type;
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
/*
const float yarn_attn_factor = (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_DEEPSEEK2OCR)
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
: cparams.yarn_attn_factor;
*/
ggml_tensor * tmp;
if (ggml_is_quantized(cur->type)) {

View File

@ -1616,15 +1616,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
{
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
bool is_ocr = (arch == LLM_ARCH_DEEPSEEK2OCR);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
if (!is_lite && !is_ocr) {
ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
}
if (!is_ocr) {
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
}
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
@ -1651,6 +1655,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.f_attn_temp_offset = 0.0f;
switch (hparams.n_layer) {
case 12: type = LLM_TYPE_3B; break;
case 27: type = LLM_TYPE_16B; break;
case 60: type = LLM_TYPE_236B; break;
case 61: type = LLM_TYPE_671B; break;
@ -4677,9 +4682,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
{
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
const bool is_ocr = (arch == LLM_ARCH_DEEPSEEK2OCR);
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
@ -4705,6 +4712,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
if (is_ocr) {
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
if (i < (int) hparams.n_layer_dense_lead) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
else {
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
// MoE branch
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
// Shared expert branch
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
}
continue;
}
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
if (!is_lite) {
layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
@ -6910,7 +6946,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
} }
if (arch == LLM_ARCH_DEEPSEEK2) { if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@ -7441,6 +7477,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique<llm_build_deepseek>(*this, params);
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
{
llm = std::make_unique<llm_build_deepseek2>(*this, params);
} break;
@ -7792,6 +7829,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:

View File

@ -2364,6 +2364,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "_<EOT>" || t.first == "_<EOT>"
|| t.first == "<|end_of_text|>" || t.first == "<|end_of_text|>"
|| t.first == "<end_of_utterance>" // smoldocling || t.first == "<end_of_utterance>" // smoldocling
|| t.first == "<end▁of▁sentence>" // deepseek-ocr
) {
special_eog_ids.insert(t.second);
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

View File

@ -4,6 +4,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
llm_graph_context(params) {
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
bool is_ocr = (model.arch == LLM_ARCH_DEEPSEEK2OCR);
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
@ -55,7 +56,38 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
// self_attention // self_attention
{ if (is_ocr) {
const int n_embed_head = hparams.n_embd / hparams.n_head();
const int ocr_rope_type = GGML_ROPE_TYPE_NEOX;
GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v);
ggml_tensor * Qcur = NULL;
ggml_tensor * Kcur = NULL;
ggml_tensor * Vcur = NULL;
Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Qcur, "q", il);
cb(Kcur, "k", il);
cb(Vcur, "v", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens);
GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0);
cb(Qcur, "q_pe", il);
cb(Kcur, "k_pe", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
else {
ggml_tensor * q = NULL;
if (!is_lite) {
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);

View File

@ -26,6 +26,7 @@ add_library(mtmd
models/qwen3vl.cpp
models/siglip.cpp
models/whisper-enc.cpp
models/deepseekocr.cpp
)
set_target_properties(mtmd PROPERTIES

View File

@ -7,6 +7,7 @@
#include <climits>
#include <cstdarg>
#include <cinttypes>
#include <cstring>
#include <string>
#include <map>
#include <sstream>
@ -52,7 +53,9 @@
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
#define KEY_SAM_N_HEAD "clip.vision.sam.head_count"
#define KEY_SAM_N_BLOCK "clip.vision.sam.block_count"
#define KEY_SAM_N_EMBD "clip.vision.sam.embedding_length"
// audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
@ -94,12 +97,13 @@
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "v.image_newline"
#define TN_IMAGE_SEPERATOR "v.view_seperator"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_NORM_B "mm.input_norm.bias"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.%s" // idefics3, deepseekocr
#define TN_MM_PATCH_MERGER "mm.patch_merger.%s" // mistral small 3.1, glm4v
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
@ -138,6 +142,20 @@
#define TN_TOK_BOI "v.boi"
#define TN_TOK_EOI "v.eoi"
// deepseek-ocr
#define TN_SAM_POS_EMBD "v.sam.pos_embd"
#define TN_SAM_PATCH_EMBD "v.sam.patch_embd.%s"
#define TN_SAM_PRE_NORM "v.sam.blk.%d.pre_ln.%s"
#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln.%s"
#define TN_SAM_ATTN_POS_H "v.sam.blk.%d.attn.pos_h"
#define TN_SAM_ATTN_POS_W "v.sam.blk.%d.attn.pos_w"
#define TN_SAM_ATTN_QKV "v.sam.blk.%d.attn.qkv.%s"
#define TN_SAM_ATTN_OUT "v.sam.blk.%d.attn.out.%s"
#define TN_SAM_FFN_UP "v.sam.blk.%d.mlp.lin1.%s"
#define TN_SAM_FFN_DOWN "v.sam.blk.%d.mlp.lin2.%s"
#define TN_SAM_NECK "v.sam.neck.%d.%s"
#define TN_SAM_NET "v.sam.net_%d.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@ -170,6 +188,7 @@ enum projector_type {
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_DEEPSEEKOCR,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_UNKNOWN,
};
@ -198,6 +217,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
};
@ -442,6 +462,32 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
// debugging
//
static std::string to_ne_string(const ggml_tensor * t) {
std::string str;
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
str += std::to_string(t->ne[i]);
if (i + 1 < GGML_MAX_DIMS) {
str += ", ";
}
}
return str;
}
static void print_tensor_info(ggml_tensor * t) {
const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1];
char src1_str[128] = {0};
if (src1) {
snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, to_ne_string(src1).c_str());
}
printf("%s: %s = %s(%s{%s}, %s)\n",
t->name, ggml_type_name(t->type), ggml_op_desc(t),
src0->name, to_ne_string(src0).c_str(),
src1 ? src1_str : "");
}
static void print_tensor_shape(ggml_tensor * t) { static void print_tensor_shape(ggml_tensor * t) {
printf("%s.shape = [", t->name); printf("%s.shape = [", t->name);
for (int i = 0; i < ggml_n_dims(t); ++i) { for (int i = 0; i < ggml_n_dims(t); ++i) {
@ -453,12 +499,50 @@ static void print_tensor_shape(ggml_tensor * t) {
printf("]\n"); printf("]\n");
} }
static void print_tensor_sum(ggml_tensor * t, uint8_t * data, int64_t n) {
(void) n; // unused parameter
ggml_type type = t->type;
int64_t * ne = t->ne;
size_t * nb = t->nb;
double sum = 0.0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i];
} else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) {
v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else {
GGML_ABORT("fatal error");
}
sum += v;
}
}
}
}
printf("%s.sum = %.6f\n", t->name, sum);
}
static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
ggml_type type = t->type;
int64_t * ne = t->ne;
size_t * nb = t->nb;
printf("%s.data: [\n", t->name);
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
if (i3 == n && ne[3] > 2*n) {
printf(" ..., \n");
i3 = ne[3] - n;
}
printf(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) {
printf(" ..., \n");
@ -500,6 +584,122 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
}
printf(" ]\n");
}
printf(" ]\n");
}
static void save_tensor_to_file(const struct ggml_tensor * tensor, const uint8_t * data_ptr) {
char filename[512];
snprintf(filename, sizeof(filename), "%s_cpp.txt", tensor->name);
FILE * f = fopen(filename, "w");
if (!f) {
fprintf(stderr, "Failed to open %s\n", filename);
return;
}
// Check tensor size and warn if too large
int64_t total_elements = ggml_nelements(tensor);
fprintf(stderr, "Saving tensor %s (%lld elements) to %s\n",
tensor->name, (long long)total_elements, filename);
if (total_elements > 10000000) { // 10M elements
fprintf(stderr, "Warning: tensor is very large (%lld elements), this may take time\n",
(long long)total_elements);
}
const uint8_t * data = (data_ptr) ? data_ptr : (uint8_t *) tensor->data;
ggml_type type = tensor->type;
const int64_t * ne = tensor->ne;
const size_t * nb = tensor->nb;
// Use a buffer to reduce I/O calls
const size_t BUF_SIZE = 8192;
char * buf = (char *) malloc(BUF_SIZE);
if (!buf) {
fprintf(stderr, "Failed to allocate buffer\n");
fclose(f);
return;
}
size_t buf_pos = 0;
// Helper lambda to flush buffer
auto flush_buf = [&]() {
if (buf_pos > 0) {
fwrite(buf, 1, buf_pos, f);
buf_pos = 0;
}
};
// Helper to append to buffer
auto append = [&](const char * str, size_t len) {
if (buf_pos + len >= BUF_SIZE) {
flush_buf();
}
if (len >= BUF_SIZE) {
// String too large for buffer, write directly
fwrite(str, 1, len, f);
} else {
memcpy(buf + buf_pos, str, len);
buf_pos += len;
}
};
auto append_str = [&](const char * str) {
append(str, strlen(str));
};
char num_buf[32];
// Write header once for all batches
append_str(tensor->name);
append_str(".data: [\n");
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
append_str(" [\n"); // Start of batch
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
append_str(" [\n");
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
append_str(" [");
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i];
} else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) {
v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else {
GGML_ABORT("fatal error");
}
int len = snprintf(num_buf, sizeof(num_buf), "%8.4f", v);
append(num_buf, len);
if (i0 < ne[0] - 1) {
append_str(", ");
}
}
append_str("],\n");
}
append_str(" ],\n");
}
append_str(" ]"); // End of batch
if (i3 < ne[3] - 1) {
append_str(",\n"); // Comma between batches
} else {
append_str("\n");
}
}
append_str("]\n"); // Close the top-level array
flush_buf();
free(buf);
fclose(f);
fprintf(stderr, "Tensor saved successfully\n");
}
void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value);

View File

@ -61,6 +61,11 @@ struct clip_hparams {
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
// deepseek-ocr (sam)
int32_t sam_n_layer = 0;
int32_t sam_n_head = 0;
int32_t sam_n_embd = 0;
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
@ -96,6 +101,21 @@ struct clip_hparams {
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
// TODO: support warmup size for custom token numbers
}
// sam vit deepseek-ocr
std::vector<int32_t> global_attn_indices() const {
return { 2, 5, 8, 11 };
}
bool is_global_attn(int32_t layer) const {
const auto indices = global_attn_indices();
for (const auto & idx : indices) {
if (layer == idx) {
return true;
}
}
return false;
}
};
struct clip_layer {
@ -142,6 +162,10 @@ struct clip_layer {
ggml_tensor * deepstack_fc2_w = nullptr;
ggml_tensor * deepstack_fc2_b = nullptr;
// sam rel_pos
ggml_tensor * rel_pos_w = nullptr;
ggml_tensor * rel_pos_h = nullptr;
bool has_deepstack() const {
return deepstack_fc1_w != nullptr;
}
@ -171,7 +195,8 @@ struct clip_model {
ggml_tensor * post_ln_w;
ggml_tensor * post_ln_b;
ggml_tensor * fc_w;
ggml_tensor * fc_b;
ggml_tensor * mm_fc_w;
ggml_tensor * mm_fc_b;
ggml_tensor * mm_ffn_up_w = nullptr;
@ -192,6 +217,8 @@ struct clip_model {
ggml_tensor * mm_2_b = nullptr;
ggml_tensor * image_newline = nullptr;
ggml_tensor * view_seperator = nullptr;
// Yi type models with mlp+normalization projection
ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
@ -286,6 +313,24 @@ struct clip_model {
ggml_tensor * mm_boi = nullptr;
ggml_tensor * mm_eoi = nullptr;
// deepseek ocr sam
ggml_tensor * patch_embed_proj_w = nullptr;
ggml_tensor * patch_embed_proj_b = nullptr;
ggml_tensor * pos_embed = nullptr;
ggml_tensor * neck_0_w;
ggml_tensor * neck_1_w;
ggml_tensor * neck_1_b;
ggml_tensor * neck_2_w;
ggml_tensor * neck_3_w;
ggml_tensor * neck_3_b;
ggml_tensor * net_2;
ggml_tensor * net_3;
int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder
std::vector<clip_layer> sam_layers;
bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL;

View File

@ -23,6 +23,7 @@
#include <limits>
#include <array>
#include <functional>
#include <algorithm>
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
@ -618,18 +619,15 @@ ggml_tensor * clip_graph::build_attn(
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not needed for vision encoders?
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
}
cb(cur, "kqv_out", il);
@ -837,6 +835,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{ {
builder = std::make_unique<clip_graph_llava>(ctx, img); builder = std::make_unique<clip_graph_llava>(ctx, img);
} break; } break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
builder = std::make_unique<clip_graph_deepseekocr>(ctx, img);
} break;
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
{ {
builder = std::make_unique<clip_graph_glm4v>(ctx, img); builder = std::make_unique<clip_graph_glm4v>(ctx, img);
@ -1187,6 +1189,17 @@ struct clip_model_loader {
hparams.audio_window_len = 400; hparams.audio_window_len = 400;
hparams.audio_hop_len = 160; hparams.audio_hop_len = 160;
} break; } break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
hparams.patch_size = 16;
hparams.image_size = 1024;
hparams.warmup_image_size = 1024;
get_u32(KEY_SAM_N_BLOCK, hparams.sam_n_layer, true);
get_u32(KEY_SAM_N_HEAD, hparams.sam_n_head, true);
get_u32(KEY_SAM_N_EMBD, hparams.sam_n_embd, true);
get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
} break;
default: default:
break; break;
} }
@ -1482,7 +1495,7 @@ struct clip_model_loader {
} break; } break;
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
{ {
model.projection = get_tensor(TN_MM_PROJECTOR); model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight")); model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight"));
model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias"), false); model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias"), false);
model.mm_ffn_gate_w = get_tensor(string_format(TN_MM_GATE, "weight")); model.mm_ffn_gate_w = get_tensor(string_format(TN_MM_GATE, "weight"));
@ -1501,7 +1514,7 @@ struct clip_model_loader {
} break; } break;
case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_IDEFICS3:
{ {
model.projection = get_tensor(TN_MM_PROJECTOR); model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
} break; } break;
case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_KIMIVL:
@ -1589,13 +1602,13 @@ struct clip_model_loader {
} break; } break;
case PROJECTOR_TYPE_LLAMA4: case PROJECTOR_TYPE_LLAMA4:
{ {
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
} break; } break;
case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_COGVLM:
{ {
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight")); model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias")); model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight")); model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight"));
@ -1611,6 +1624,42 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
} break; } break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
model.pos_embed = get_tensor(TN_SAM_POS_EMBD);
model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight"));
model.patch_embed_proj_b = get_tensor(string_format(TN_SAM_PATCH_EMBD, "bias"));
model.sam_layers.resize(model.n_sam_layers);
for (int il = 0; il < model.n_sam_layers; ++il) {
auto & layer = model.sam_layers[il];
layer.qkv_w = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "weight"));
layer.qkv_b = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "bias"));
layer.o_w = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "weight"));
layer.o_b = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "bias"));
layer.ln_1_w = get_tensor(string_format(TN_SAM_PRE_NORM, il, "weight"));
layer.ln_1_b = get_tensor(string_format(TN_SAM_PRE_NORM, il, "bias"));
layer.ln_2_w = get_tensor(string_format(TN_SAM_POST_NORM, il, "weight"));
layer.ln_2_b = get_tensor(string_format(TN_SAM_POST_NORM, il, "bias"));
layer.rel_pos_h = get_tensor(string_format(TN_SAM_ATTN_POS_H, il));
layer.rel_pos_w = get_tensor(string_format(TN_SAM_ATTN_POS_W, il));
layer.ff_up_w = get_tensor(string_format(TN_SAM_FFN_UP, il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_SAM_FFN_UP, il, "bias"));
layer.ff_down_w = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "bias"));
}
model.neck_0_w = get_tensor(string_format(TN_SAM_NECK, 0, "weight"));
model.neck_1_b = get_tensor(string_format(TN_SAM_NECK, 1, "bias"));
model.neck_1_w = get_tensor(string_format(TN_SAM_NECK, 1, "weight"));
model.neck_2_w = get_tensor(string_format(TN_SAM_NECK, 2, "weight"));
model.neck_3_b = get_tensor(string_format(TN_SAM_NECK, 3, "bias"));
model.neck_3_w = get_tensor(string_format(TN_SAM_NECK, 3, "weight"));
model.net_2 = get_tensor(string_format(TN_SAM_NET, 2, "weight"));
model.net_3 = get_tensor(string_format(TN_SAM_NET, 3, "weight"));
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
model.view_seperator = get_tensor(TN_IMAGE_SEPERATOR);
model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
} break;
default: default:
GGML_ASSERT(false && "unknown projector type"); GGML_ASSERT(false && "unknown projector type");
} }
@ -2012,6 +2061,7 @@ struct img_tool {
enum resize_algo { enum resize_algo {
RESIZE_ALGO_BILINEAR, RESIZE_ALGO_BILINEAR,
RESIZE_ALGO_BICUBIC, RESIZE_ALGO_BICUBIC,
RESIZE_ALGO_BICUBIC_PILLOW,
// RESIZE_ALGO_LANCZOS, // TODO // RESIZE_ALGO_LANCZOS, // TODO
}; };
@ -2041,6 +2091,9 @@ struct img_tool {
case RESIZE_ALGO_BICUBIC: case RESIZE_ALGO_BICUBIC:
resize_bicubic(src, dst, target_resolution.width, target_resolution.height); resize_bicubic(src, dst, target_resolution.width, target_resolution.height);
break; break;
case RESIZE_ALGO_BICUBIC_PILLOW:
resize_bicubic_pillow(src, dst, target_resolution.width, target_resolution.height);
break;
default: default:
throw std::runtime_error("Unsupported resize algorithm"); throw std::runtime_error("Unsupported resize algorithm");
} }
@ -2060,6 +2113,9 @@ struct img_tool {
case RESIZE_ALGO_BICUBIC: case RESIZE_ALGO_BICUBIC:
resize_bicubic(src, resized_image, new_width, new_height); resize_bicubic(src, resized_image, new_width, new_height);
break; break;
case RESIZE_ALGO_BICUBIC_PILLOW:
resize_bicubic_pillow(src, resized_image, new_width, new_height);
break;
default: default:
throw std::runtime_error("Unsupported resize algorithm"); throw std::runtime_error("Unsupported resize algorithm");
} }
@ -2270,6 +2326,255 @@ private:
return true; return true;
} }
// Bicubic resize function using Pillow's ImagingResample algorithm
// Adapted from https://github.com/python-pillow/Pillow/blob/main/src/libImaging/Resample.c
//
// Key differences from resize_bicubic:
// 1. Uses separable filtering: horizontal pass followed by vertical pass
// 2. Pre-computes normalized filter coefficients for each output pixel
// 3. Applies convolution using fixed-point integer arithmetic for performance
static bool resize_bicubic_pillow(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
// Fixed-point precision: 22 bits = 32 (int32_t) - 8 (uint8_t pixels) - 2 (headroom for accumulation)
// This allows encoding fractional weights as integers: weight * 2^22
const int PRECISION_BITS = 32 - 8 - 2;
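// Illustrative arithmetic (not part of the upstream code): a weight of 0.25 is stored as
// round(0.25 * 2^22) = 1048576; normalized weights sum to ~1.0 (~2^22 in fixed point), so a
// full convolution accumulates at most ~255 * 2^22, roughly 2^30, well within int32_t range.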
// Bicubic filter function with a = -0.5 (Note that GGML/PyTorch takes a = -0.75)
// Returns filter weight for distance x from pixel center
// Support: [-2, 2], meaning the filter influences pixels within 2 units of distance
auto bicubic_filter = [](double x) -> double {
constexpr double a = -0.5;
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
}
if (x < 2.0) {
return (((x - 5) * x + 8) * x - 4) * a;
}
return 0.0; // Zero outside [-2, 2]
};
// Filter support radius: bicubic extends 2 pixels in each direction
constexpr double filter_support = 2.0;
// Clipping function for 8-bit values
auto clip8 = [](int val) -> uint8_t {
if (val < 0) return 0;
if (val > 255) return 255;
return static_cast<uint8_t>(val);
};
// Precompute filter coefficients for ONE dimension (horizontal or vertical)
//
// Parameters:
// inSize - Number of pixels in input dimension (e.g., src_width or src_height)
// outSize - Number of pixels in output dimension (e.g., target_width or target_height)
// bounds - [OUTPUT] Array of size outSize*2 storing input pixel ranges:
// bounds[xx*2+0] = first input pixel index for output pixel xx (xmin)
// bounds[xx*2+1] = number of input pixels for output pixel xx (xcnt)
// weights - [OUTPUT] Array of size outSize*ksize storing fixed-point filter weights:
// weights[xx*ksize + x] = weight for input pixel x contributing to output pixel xx
//
// Returns: kernel size (ksize) - number of input pixels that contribute to each output pixel
auto precompute_weights = [&](int inSize, int outSize,
std::vector<int> & bounds, std::vector<int32_t> & weights) -> int {
double support, scale, filterscale;
double center, ww, ss;
int xx, x, ksize, xmin, xmax, xcnt;
// Calculate scaling factor: ratio of input range to output size
filterscale = scale = (double)inSize / outSize;
// For upsampling (scale < 1), keep filterscale = 1 to maintain filter sharpness
// For downsampling (scale > 1), widen filter to prevent aliasing
if (filterscale < 1.0) {
filterscale = 1.0;
}
// Determine filter support radius and kernel size
support = filter_support * filterscale; // Widen filter when downsampling
ksize = static_cast<int>(std::ceil(support)) * 2 + 1; // Total pixels in kernel
std::vector<double> pre_weights(outSize * ksize); // Temporary weights
bounds.resize(outSize * 2);
// For each output pixel, compute its filter coefficients
for (xx = 0; xx < outSize; xx++) {
// Calculate the center position in input space (pixel-center convention: +0.5)
center = (xx + 0.5) * scale;
ww = 0.0; // Sum of weights for normalization
ss = 1.0 / filterscale; // Scale factor for filter function
// Determine the range of input pixels that contribute to this output pixel
xmin = static_cast<int>(center - support + 0.5);
if (xmin < 0) {
xmin = 0;
}
xmax = static_cast<int>(center + support + 0.5);
if (xmax > inSize) {
xmax = inSize;
}
xcnt = xmax - xmin;
// Compute filter weights for each contributing input pixel
for (x = 0; x < xcnt; x++) {
// Distance from input pixel center to output pixel center in input space
double w = bicubic_filter((x + xmin - center + 0.5) * ss);
pre_weights[xx * ksize + x] = w;
ww += w; // Accumulate for normalization
}
// Normalize weights to sum to 1.0 (preserves brightness)
for (x = 0; x < xcnt; x++) {
if (ww != 0.0) {
pre_weights[xx * ksize + x] /= ww;
}
}
// Zero-pad remaining kernel positions
for (; x < ksize; x++) {
pre_weights[xx * ksize + x] = 0;
}
// Store input pixel range for this output pixel
bounds[xx * 2 + 0] = xmin;
bounds[xx * 2 + 1] = xcnt;
}
// Convert floating-point coefficients to fixed-point integers
// Formula: int32 = round(float * 2^PRECISION_BITS)
weights.resize(outSize * ksize);
for (int i = 0; i < outSize * ksize; i++) {
if (pre_weights[i] < 0) {
weights[i] = static_cast<int32_t>(-0.5 + pre_weights[i] * (1 << PRECISION_BITS));
} else {
weights[i] = static_cast<int32_t>(0.5 + pre_weights[i] * (1 << PRECISION_BITS));
}
}
return ksize;
};
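// Worked example (illustrative, assuming a plain downscale): resizing 1024 -> 512 gives
// scale = filterscale = 2.0, support = 2 * 2 = 4 and ksize = ceil(4) * 2 + 1 = 9; for output
// pixel 0, center = 1.0, so xmin = 0, xmax = 5 and the pixel blends input pixels 0..4 with
// normalized bicubic weights.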
// Horizontal resampling pass
// Resizes width from imIn.nx to imOut.nx, preserving height
auto resample_horizontal = [&](const clip_image_u8 & imIn, clip_image_u8 & imOut,
int ksize, const std::vector<int> & bounds, const std::vector<int32_t> & weights) {
imOut.ny = imIn.ny;
imOut.buf.resize(3 * imOut.nx * imOut.ny);
// Process each row independently
for (int yy = 0; yy < imOut.ny; yy++) {
// For each output pixel in this row
for (int xx = 0; xx < imOut.nx; xx++) {
// Get the range of input pixels and filter coefficients
int xmin = bounds[xx * 2 + 0]; // First input pixel index
int xcnt = bounds[xx * 2 + 1]; // Number of input pixels
// Initialize accumulators for RGB channels with rounding bias (0.5 in fixed-point)
int32_t ss0 = 1 << (PRECISION_BITS - 1);
int32_t ss1 = 1 << (PRECISION_BITS - 1);
int32_t ss2 = 1 << (PRECISION_BITS - 1);
// Convolve: sum weighted input pixels
for (int x = 0; x < xcnt; x++) {
int src_idx = ((yy * imIn.nx) + (x + xmin)) * 3;
ss0 += static_cast<uint8_t>(imIn.buf[src_idx + 0]) * weights[xx * ksize + x]; // R channel
ss1 += static_cast<uint8_t>(imIn.buf[src_idx + 1]) * weights[xx * ksize + x]; // G channel
ss2 += static_cast<uint8_t>(imIn.buf[src_idx + 2]) * weights[xx * ksize + x]; // B channel
}
// Convert back from fixed-point (divide by 2^PRECISION_BITS) and clamp to [0,255]
int dst_idx = (yy * imOut.nx + xx) * 3;
imOut.buf[dst_idx + 0] = clip8(ss0 >> PRECISION_BITS);
imOut.buf[dst_idx + 1] = clip8(ss1 >> PRECISION_BITS);
imOut.buf[dst_idx + 2] = clip8(ss2 >> PRECISION_BITS);
}
}
};
// Vertical resampling pass
// Resizes height from imIn.ny to imOut.ny, preserving width
auto resample_vertical = [&](const clip_image_u8 & imIn, clip_image_u8 & imOut,
int ksize, const std::vector<int> & bounds, const std::vector<int32_t> & weight) {
imOut.nx = imIn.nx;
imOut.buf.resize(3 * imOut.nx * imOut.ny);
// For each output row
for (int yy = 0; yy < imOut.ny; yy++) {
// Get the range of input rows and filter coefficients
int ymin = bounds[yy * 2 + 0]; // First input row index
int ycnt = bounds[yy * 2 + 1]; // Number of input rows
// Process each column in this output row
for (int xx = 0; xx < imOut.nx; xx++) {
// Initialize accumulators for RGB channels with rounding bias
int32_t ss0 = 1 << (PRECISION_BITS - 1);
int32_t ss1 = 1 << (PRECISION_BITS - 1);
int32_t ss2 = 1 << (PRECISION_BITS - 1);
// Convolve: sum weighted input pixels vertically
for (int y = 0; y < ycnt; y++) {
int src_idx = ((y + ymin) * imIn.nx + xx) * 3;
ss0 += static_cast<uint8_t>(imIn.buf[src_idx + 0]) * weight[yy * ksize + y]; // R channel
ss1 += static_cast<uint8_t>(imIn.buf[src_idx + 1]) * weight[yy * ksize + y]; // G channel
ss2 += static_cast<uint8_t>(imIn.buf[src_idx + 2]) * weight[yy * ksize + y]; // B channel
}
// Convert back from fixed-point and clamp to [0,255]
int dst_idx = (yy * imOut.nx + xx) * 3;
imOut.buf[dst_idx + 0] = clip8(ss0 >> PRECISION_BITS);
imOut.buf[dst_idx + 1] = clip8(ss1 >> PRECISION_BITS);
imOut.buf[dst_idx + 2] = clip8(ss2 >> PRECISION_BITS);
}
}
};
// Main resampling logic using separable two-pass approach
const int src_width = img.nx;
const int src_height = img.ny;
dst.nx = target_width;
dst.ny = target_height;
bool need_horizontal = (target_width != src_width);
bool need_vertical = (target_height != src_height);
// Precompute filter coefficients for both dimensions
std::vector<int> bounds_horiz, bounds_vert;
std::vector<int32_t> weights_horiz, weights_vert;
int ksize_horiz = 0, ksize_vert = 0;
if (need_horizontal) {
ksize_horiz = precompute_weights(src_width, target_width, bounds_horiz, weights_horiz);
}
if (need_vertical) {
ksize_vert = precompute_weights(src_height, target_height, bounds_vert, weights_vert);
}
// Perform two-pass resampling
if (need_horizontal && need_vertical) {
// Both horizontal and vertical
clip_image_u8 temp;
temp.nx = target_width;
resample_horizontal(img, temp, ksize_horiz, bounds_horiz, weights_horiz);
resample_vertical(temp, dst, ksize_vert, bounds_vert, weights_vert);
} else if (need_horizontal) {
// Only horizontal
resample_horizontal(img, dst, ksize_horiz, bounds_horiz, weights_horiz);
} else if (need_vertical) {
// Only vertical
resample_vertical(img, dst, ksize_vert, bounds_vert, weights_vert);
} else {
// No resizing needed - direct copy
dst.buf = img.buf;
}
return true;
}
static inline int clip(int x, int lower, int upper) { static inline int clip(int x, int lower, int upper) {
return std::max(lower, std::min(x, upper)); return std::max(lower, std::min(x, upper));
} }
@ -2582,6 +2887,59 @@ private:
} }
}; };
static std::vector<std::pair<int, int>> ds_build_target_ratios(const int min_num, const int max_num) {
std::vector<std::pair<int, int>> ratios;
for (int n = min_num; n <= max_num; ++n) {
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
if (const int blocks = i * j; blocks >= min_num && blocks <= max_num) {
ratios.emplace_back(i, j); // (cols, rows)
}
}
}
}
// sort by total blocks like in Python (key=lambda x: x[0] * x[1]);
// tie-break on the pair itself so equal ratios end up adjacent and std::unique can drop them
std::sort(ratios.begin(), ratios.end(),
[](const auto &a, const auto &b) {
const int pa = a.first * a.second;
const int pb = b.first * b.second;
return pa != pb ? pa < pb : a < b;
});
// dedup (the nested loops above generate the same ratio multiple times)
ratios.erase(std::unique(ratios.begin(), ratios.end()), ratios.end());
return ratios;
}
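// Example (illustrative call, arguments are hypothetical): ds_build_target_ratios(1, 4) yields,
// after sorting by i*j and deduplication, (1,1), (1,2), (2,1), (1,3), (3,1), (1,4), (2,2), (4,1).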
static std::pair<int, int> ds_find_closest_ratio(
const float aspect_ratio,
const std::vector<std::pair<int, int>> &target_ratios,
const int width,
const int height,
const int image_size
) {
float best_diff = std::numeric_limits<float>::infinity();
std::pair<int, int> best_ratio = {1, 1};
const float area = static_cast<float>(width) * static_cast<float>(height);
for (const auto &r : target_ratios) {
const float target_ar = static_cast<float>(r.first) / static_cast<float>(r.second);
if (const float diff = std::fabs(aspect_ratio - target_ar); diff < best_diff) {
best_diff = diff;
best_ratio = r;
} else if (diff == best_diff) {
// same as the Python reference: prefer this ratio if the image area is "large enough"
if (const float needed_area = 0.5f * image_size * image_size * r.first * r.second; area > needed_area) {
best_ratio = r;
}
}
}
return best_ratio; // (cols, rows)
}
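// Example (illustrative): a 1600x800 input has aspect ratio 2.0, so among the ratios built
// above the closest match is (2, 1), i.e. a 2-column x 1-row tiling.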
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found // res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
@ -2798,6 +3156,89 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
} }
} }
} break; } break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
const std::vector native_resolutions = {
/*512 tiny , 640 small, */ 1024 /* base */, 1280 /* large */
};
// original image size
const int orig_w = original_size.width;
const int orig_h = original_size.height;
const int orig_area = orig_h * orig_w;
std::array<uint8_t, 3u> color;
for (int i = 0; i < 3; i++) {
color[i] = (int)(255 * params.image_mean[i]);
}
size_t mode_i = 0;
int min_diff = orig_area;
for (size_t i = 0; i < native_resolutions.size(); i++) {
int r = native_resolutions[i];
if (std::abs(orig_area - r * r) < min_diff) {
mode_i = i;
min_diff = std::abs(orig_area - r * r);
}
}
/* Native Resolution (Base/Large) */
const int image_size = native_resolutions[mode_i];
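// Worked example (illustrative): a 1000x800 page has area 800000;
// |800000 - 1024^2| = 248576 is smaller than |800000 - 1280^2| = 838400,
// so the 1024 ("base") resolution is chosen.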
// Resize maintaining aspect ratio, then pad to square
float scale = std::min(
static_cast<float>(image_size) / orig_w,
static_cast<float>(image_size) / orig_h
);
int new_w = static_cast<int>(orig_w * scale);
int new_h = static_cast<int>(orig_h * scale);
clip_image_u8_ptr scaled_img(clip_image_u8_init());
img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h},
img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color);
// Use mean color for padding
unsigned char pad_r = static_cast<unsigned char>(params.image_mean[0] * 255.0f);
unsigned char pad_g = static_cast<unsigned char>(params.image_mean[1] * 255.0f);
unsigned char pad_b = static_cast<unsigned char>(params.image_mean[2] * 255.0f);
// Pad to image_size × image_size (center padding)
clip_image_u8_ptr padded_img(clip_image_u8_init());
padded_img->nx = image_size;
padded_img->ny = image_size;
padded_img->buf.resize(image_size * image_size * 3); // allocate canvas; filled with mean color below
// Fill with mean color
for (int i = 0; i < image_size * image_size; ++i)
{
padded_img->buf[i * 3 + 0] = pad_r;
padded_img->buf[i * 3 + 1] = pad_g;
padded_img->buf[i * 3 + 2] = pad_b;
}
// Calculate padding offsets (center the image)
int pad_x = (image_size - new_w) / 2;
int pad_y = (image_size - new_h) / 2;
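// Worked example (illustrative): an 800x600 input with image_size = 1024 gives
// scale = min(1024/800, 1024/600) = 1.28, so new_w = 1024, new_h = 768,
// pad_x = 0 and pad_y = 128 rows of mean-color padding above and below.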
// Copy scaled image into padded canvas
for (int y = 0; y < new_h; ++y){
for (int x = 0; x < new_w; ++x){
int src_idx = (y * new_w + x) * 3;
int dst_idx = ((y + pad_y) * image_size + (x + pad_x)) * 3;
padded_img->buf[dst_idx + 0] = scaled_img->buf[src_idx + 0];
padded_img->buf[dst_idx + 1] = scaled_img->buf[src_idx + 1];
padded_img->buf[dst_idx + 2] = scaled_img->buf[src_idx + 2];
}
}
// Normalize and output
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(*padded_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
res_imgs->grid_x = 1;
res_imgs->grid_y = 1;
} break;
default: default:
LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type()); LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type());
@ -3004,6 +3445,18 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{ {
n_patches += 2; // for BOI and EOI token embeddings n_patches += 2; // for BOI and EOI token embeddings
} break; } break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
// SAM encoder applies two stride-2 convolutions (net_2 and net_3)
// which reduce spatial dimensions by 4x in each direction (16x fewer patches overall)
// E.g., 64x64 -> 16x16 patches
n_patches /= 16;
// build_global_local_features adds image newlines and view separator
// Formula: h*(w+1) + 1 where h = w = sqrt(n_patches)
int h = static_cast<int>(std::sqrt(static_cast<float>(n_patches)));
n_patches = h * (h + 1) + 1;
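// Worked example (illustrative): a 1024x1024 input with patch_size 16 gives 64x64 = 4096
// patches; after the 4x per-dimension reduction this becomes 16x16 = 256, so h = 16 and
// the final token count is 16*17 + 1 = 273.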
} break;
default: default:
GGML_ABORT("unsupported projector type"); GGML_ABORT("unsupported projector type");
} }
@ -3332,6 +3785,30 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} }
set_input_i32("patches", patches); set_input_i32("patches", patches);
} break; } break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
GGML_ASSERT(pos_w == pos_h);
const int window = hparams.attn_window_size;
const int pos = pos_w;
std::vector<int32_t> rel_pos_indices_local(window * window);
std::vector<int32_t> rel_pos_indices_global(pos * pos);
for (int q = 0; q < window; q++) {
for (int k = 0; k < window; k++) {
rel_pos_indices_local[q * window + k] = q - k + window - 1;
}
}
for (int q = 0; q < pos; q++) {
for (int k = 0; k < pos; k++) {
rel_pos_indices_global[q * pos + k] = q - k + pos - 1;
}
}
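// Example (illustrative): with window = 14 the local indices span [0, 2*14 - 2] = [0, 26],
// matching a rel_pos table of length 2*14 - 1 = 27 (q = 0, k = 13 -> 0; q = 13, k = 0 -> 26).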
set_input_i32("rel_pos_indices_local", rel_pos_indices_local);
set_input_i32("rel_pos_indices_global", rel_pos_indices_global);
} break;
case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
@ -3389,8 +3866,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
for (ggml_tensor * t : ctx->debug_print_tensors) { for (ggml_tensor * t : ctx->debug_print_tensors) {
std::vector<uint8_t> data(ggml_nbytes(t)); std::vector<uint8_t> data(ggml_nbytes(t));
ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
print_tensor_info(t);
print_tensor_shape(t); print_tensor_shape(t);
print_tensor_data(t, data.data(), 3); print_tensor_sum(t, data.data(), 3);
std::string tname_s = std::string(t->name);
bool is_stored = false;
std::vector<std::string> patterns = {
/* Add tensor names here to dump (e.g. "sam_output") */
};
for (auto & p : patterns) {
if (tname_s == p) {
save_tensor_to_file(t, data.data());
is_stored = true;
break;
}
}
if (!is_stored) {
print_tensor_data(t, data.data(), 3);
}
} }
} }
@ -3439,7 +3935,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3:
return ctx->model.mm_input_proj_w->ne[0]; return ctx->model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_IDEFICS3:
return ctx->model.projection->ne[1]; return ctx->model.fc_w->ne[1];
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
return ctx->model.mm_2_w->ne[1]; return ctx->model.mm_2_w->ne[1];
@ -3456,6 +3952,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1]; return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1]; return ctx->model.mm_4h_to_h_w->ne[1];
case PROJECTOR_TYPE_DEEPSEEKOCR:
return ctx->model.fc_w->ne[1];
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
return ctx->model.mm_ffn_down_w->ne[1]; return ctx->model.mm_ffn_down_w->ne[1];
default: default:
@ -3489,6 +3987,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
} }
bool clip_is_deepseekocr(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_DEEPSEEKOCR;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) { bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
return ctx->model.modality == CLIP_MODALITY_VISION; return ctx->model.modality == CLIP_MODALITY_VISION;
} }

View File

@ -107,6 +107,8 @@ bool clip_is_glm(const struct clip_ctx * ctx);
bool clip_is_mrope(const struct clip_ctx * ctx); bool clip_is_mrope(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx);
bool clip_is_deepseekocr(const struct clip_ctx * ctx);
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

View File

@ -0,0 +1,324 @@
#include "models.h"
// Implementation based on approach suggested by Acly
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
static ggml_tensor * window_partition(ggml_context * ctx0, ggml_tensor * x, const int window) {
auto [c, w, h, b] = x->ne;
// same as
// x = ggml_win_part(m, x, window);
// x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]);
const int64_t px = (window - w % window) % window;
const int64_t py = (window - h % window) % window;
const int64_t npw = (w + px) / window;
const int64_t nph = (h + py) / window;
ggml_tensor * cur = x;
if (px > 0 || py > 0) {
cur = ggml_pad(ctx0, cur, 0, static_cast<int>(px), static_cast<int>(py), 0);
}
cur = ggml_reshape_4d(ctx0, cur, c * window, npw, window, nph * b);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
cur = ggml_reshape_4d(ctx0, cur, c, window, window, npw * nph * b);
return cur;
}
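// Shape example (illustrative): with window = 14 and x of shape [c, 64, 64, 1],
// px = py = (14 - 64 % 14) % 14 = 6 and npw = nph = (64 + 6) / 14 = 5, so the result
// has shape [c, 14, 14, 25], i.e. 25 windows of 14x14 patches.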
// Implementation based on approach suggested by Acly
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
static ggml_tensor * window_unpartition(ggml_context * ctx0,
ggml_tensor * x,
const int w,
const int h,
const int window) {
const int64_t c = x->ne[0];
// same as
// x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]);
// x = ggml_win_unpart(m, x, w, h, window);
const int64_t px = (window - w % window) % window;
const int64_t py = (window - h % window) % window;
const int64_t npw = (w + px) / window;
const int64_t nph = (h + py) / window;
const int64_t b = x->ne[3] / (npw * nph);
ggml_tensor * cur = x;
cur = ggml_reshape_4d(ctx0, cur, c * window, window, npw, nph * b);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
cur = ggml_reshape_4d(ctx0, cur, c, w + px, h + py, b);
cur = ggml_view_4d(ctx0, cur, cur->ne[0], w, h, cur->ne[3], cur->nb[1], cur->nb[2], cur->nb[3], 0);
cur = ggml_cont(ctx0, cur);
return cur;
}
static ggml_tensor * get_rel_pos(ggml_context * ctx0,
ggml_tensor * rel_pos, // [L, C]
ggml_tensor * indices, // [q_size, k_size]
const int q_size,
const int k_size) {
const int64_t C = rel_pos->ne[0]; // channels
const int64_t L = rel_pos->ne[1]; // length
GGML_ASSERT(indices != nullptr);
GGML_ASSERT(indices->type == GGML_TYPE_I32);
GGML_ASSERT(indices->ne[0] == k_size);
GGML_ASSERT(indices->ne[1] == q_size);
const auto max_rel_dist = 2 * std::max(q_size, k_size) - 1;
ggml_tensor * cur = rel_pos;
if (max_rel_dist != L) {
// Linear interpolation
const int64_t ne0 = cur->ne[0];
const int64_t ne1 = cur->ne[1];
const int64_t ne2 = cur->ne[2];
const int64_t ne3 = cur->ne[3];
cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3)), ne1, 1, ne0 * ne2 * ne3);
cur = ggml_reshape_4d(
ctx0, ggml_interpolate(ctx0, cur, max_rel_dist, 1, ne0 * ne2 * ne3, 1, GGML_SCALE_MODE_BILINEAR),
max_rel_dist, ne0, ne2, ne3);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3));
}
// Flatten indices to 1D for ggml_get_rows
const int qk = q_size * k_size;
cur = ggml_reshape_3d(ctx0, ggml_get_rows(ctx0, cur, ggml_reshape_1d(ctx0, indices, qk)), C, k_size, q_size);
return cur; // [C, k_size, q_size]
}
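// Example (illustrative): for a global-attention layer with q_size = k_size = 64,
// max_rel_dist = 2*64 - 1 = 127; if the stored rel_pos table has a different length L,
// it is first resized to 127 entries via linear interpolation, then gathered per (q, k) pair.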
ggml_cgraph * clip_graph_deepseekocr::build() {
// patch embedding
ggml_tensor * inp_raw = build_inp_raw();
ggml_tensor * sam_out;
// Building SAM
{
const int n_embd = hparams.sam_n_embd;
const int n_layer = hparams.sam_n_layer;
const int n_heads = hparams.sam_n_head;
const int d_heads = n_embd / n_heads;
const int window = hparams.attn_window_size;
ggml_tensor * inpL;
inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd));
inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
ggml_tensor * rel_pos_indices_local;
ggml_tensor * rel_pos_indices_global;
rel_pos_indices_local = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, window, window);
rel_pos_indices_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, inpL->ne[1], inpL->ne[2]);
ggml_set_name(rel_pos_indices_local, "rel_pos_indices_local");
ggml_set_name(rel_pos_indices_global, "rel_pos_indices_global");
ggml_set_input(rel_pos_indices_local);
ggml_set_input(rel_pos_indices_global);
ggml_tensor * cur;
const auto tgt_size = inpL->ne[1];
const auto str_size = model.pos_embed->ne[1];
if (str_size != tgt_size) {
ggml_tensor * old_pos_embed = nullptr;
old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
ggml_tensor * new_pos_embed =
ggml_interpolate(ctx0, old_pos_embed, tgt_size, tgt_size, n_embd, 1, GGML_SCALE_MODE_BICUBIC);
new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
cur = ggml_add(ctx0, inpL, new_pos_embed);
} else {
cur = ggml_add(ctx0, inpL, model.pos_embed);
}
// loop over layers
for (int il = 0; il < n_layer; il++) {
auto & layer = model.sam_layers[il];
ggml_tensor * shortcut = cur;
// layernorm1
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
const int64_t w0 = cur->ne[1];
const int64_t h0 = cur->ne[2];
ggml_tensor * indices;
if (hparams.is_global_attn(il)) {
indices = rel_pos_indices_global;
} else {
// local attention layer - apply window partition
cur = window_partition(ctx0, cur, window);
indices = rel_pos_indices_local;
}
const int64_t W = cur->ne[1];
const int64_t H = cur->ne[2];
// self-attention
{
const int B = cur->ne[3];
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b);
cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
cur = ggml_reshape_4d(ctx0, cur, n_embd, 3, W * H, B);
ggml_tensor * Q;
ggml_tensor * K;
ggml_tensor * V;
Q = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 0 * cur->nb[1]);
Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W * H, B);
K = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 1 * cur->nb[1]);
K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W * H, B);
V = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 2 * cur->nb[1]);
V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W * H, B);
ggml_tensor * mask;
ggml_tensor * rw;
ggml_tensor * rh;
ggml_tensor * qr;
rw = get_rel_pos(ctx0, layer.rel_pos_w, indices, W, W); // [W, W, C]
rh = get_rel_pos(ctx0, layer.rel_pos_h, indices, H, H); // [H, H, C]
qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);
rw = ggml_mul_mat(ctx0, rw,
ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*n_heads, W, H, W]
rw = ggml_cont(ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*n_heads, H, W, W]
rw = ggml_reshape_4d(ctx0, rw, W, 1, W * H, n_heads * B);
rw = ggml_repeat_4d(ctx0, rw, W, H, W * H, n_heads * B);
rh = ggml_mul_mat(ctx0, rh, qr); // [B*n_heads, H, W, H]
rh = ggml_reshape_4d(ctx0, rh, 1, H, W * H, n_heads * B);
mask = ggml_add(ctx0, rw, rh); // [B*n_heads, H*W, H, W]
mask = ggml_reshape_4d(ctx0, mask, W * H, W * H, n_heads, B);
mask = ggml_cast(ctx0, mask, GGML_TYPE_F16);
const float scale = 1.0f / sqrtf(static_cast<float>(d_heads));
cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
il); // [B, H*W, n_embd]
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
}
if (hparams.is_global_attn(il) == false) {
// local attention layer - reverse window partition
cur = window_unpartition(ctx0, cur, w0, h0, window);
}
// re-add the layer input, i.e. the residual connection
cur = ggml_add(ctx0, cur, shortcut);
ggml_tensor * inpFF = cur;
// layernorm2
cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
// ffn
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b,
hparams.ffn_op, il);
// residual 2
cur = ggml_add(ctx0, cur, inpFF);
cb(cur, "sam_layer_out", il);
}
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
cb(cur, "sam_output", -1);
ggml_build_forward_expand(gf, cur);
sam_out = cur;
}
ggml_tensor * clip_out;
// Building DS-OCR CLIP
{
ggml_tensor * inp;
inp = ggml_cpy(ctx0, sam_out, ggml_dup_tensor(ctx0, sam_out));
inp = ggml_reshape_2d(ctx0, inp, inp->ne[0] * inp->ne[1], inp->ne[2]);
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
ggml_tensor * new_pos_embd =
ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
int n_pos = new_pos_embd->ne[1]; // +1 for [CLS]
const auto tgt_size = static_cast<int>(std::sqrt(inp->ne[1]));
const auto src_size = static_cast<int>(std::sqrt(n_pos - 1));
if (tgt_size != src_size) {
ggml_tensor * old_pos_embd;
ggml_tensor * cls_tok;
old_pos_embd = ggml_view_2d(ctx0, new_pos_embd, new_pos_embd->ne[0], src_size * src_size,
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0);
cls_tok = ggml_view_2d(ctx0, new_pos_embd, new_pos_embd->ne[0], 1,
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size);
new_pos_embd = ggml_interpolate(ctx0, old_pos_embd, tgt_size, tgt_size, new_pos_embd->ne[0], 1,
GGML_SCALE_MODE_BICUBIC);
new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1);
new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1);
n_pos = tgt_size * tgt_size + 1;
}
// add CLS token
inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
// for selecting learned pos embd, used by ViT
ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, FFN_GELU_QUICK, learned_pos_embd, nullptr);
ggml_build_forward_expand(gf, cur);
clip_out = cur;
}
const int clip_n_patches = sam_out->ne[0] * sam_out->ne[1];
sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3));
sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches);
clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]);
ggml_tensor * cur;
cur = ggml_concat(ctx0, clip_out, sam_out, 0);
cur = ggml_reshape_2d(ctx0, cur, 2 * n_embd, clip_n_patches);
cur = ggml_cont(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.fc_w, cur);
cur = ggml_add(ctx0, cur, model.fc_b);
const auto h = static_cast<int>(std::sqrt(static_cast<float>(cur->ne[1])));
const auto w = h;
const auto n_dim = cur->ne[0];
ggml_tensor * imgnl;
ggml_tensor * vs;
imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1);
vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
cur = ggml_reshape_3d(ctx0, cur, n_dim, w, h);
cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w + 1) * h);
cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1)
cb(cur, "dsocr_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}

View File

@ -95,7 +95,7 @@ ggml_cgraph * clip_graph_glm4v::build() {
// FC projector // FC projector
{ {
cur = ggml_mul_mat(ctx0, model.projection, cur); cur = ggml_mul_mat(ctx0, model.fc_w, cur);
// default LayerNorm (post_projection_norm) // default LayerNorm (post_projection_norm)
cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
cur = ggml_gelu_erf(ctx0, cur); cur = ggml_gelu_erf(ctx0, cur);

View File

@ -57,6 +57,11 @@ struct clip_graph_whisper_enc : clip_graph {
ggml_cgraph * build() override; ggml_cgraph * build() override;
}; };
struct clip_graph_deepseekocr : clip_graph {
clip_graph_deepseekocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_glm4v : clip_graph { struct clip_graph_glm4v : clip_graph {
clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override; ggml_cgraph * build() override;

View File

@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_siglip::build() {
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.n_merge; const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor); cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur); cur = ggml_mul_mat(ctx0, model.fc_w, cur);
} else if (proj_type == PROJECTOR_TYPE_LFM2) { } else if (proj_type == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block // pixel unshuffle block

View File

@ -174,7 +174,7 @@ struct mtmd_context {
clip_context_params ctx_clip_params { clip_context_params ctx_clip_params {
/* use_gpu */ ctx_params.use_gpu, /* use_gpu */ ctx_params.use_gpu,
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
/* image_min_tokens */ ctx_params.image_min_tokens, /* image_min_tokens */ ctx_params.image_min_tokens,
/* image_max_tokens */ ctx_params.image_max_tokens, /* image_max_tokens */ ctx_params.image_max_tokens,
/* warmup */ ctx_params.warmup, /* warmup */ ctx_params.warmup,
@ -827,7 +827,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
if (clip_is_llava(ctx_clip) if (clip_is_llava(ctx_clip)
|| clip_is_minicpmv(ctx_clip) || clip_is_minicpmv(ctx_clip)
|| clip_is_glm(ctx_clip)) { || clip_is_glm(ctx_clip)
|| clip_is_deepseekocr(ctx_clip)) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries; const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) { for (size_t i = 0; i < entries.size(); i++) {

View File

@ -28,6 +28,14 @@ if [ "${1:-}" = "huge" ]; then
echo "Include BIG and HUGE models..." echo "Include BIG and HUGE models..."
fi fi
# Check if either argument is "flash_off", then disable flash attention
# This is useful to test that everything still works with flash attention turned off
FLASH_ATTN="on"
if [ "${2:-}" = "flash_off" ] || [ "${1:-}" = "flash_off" ]; then
FLASH_ATTN="off"
echo "Flash attention disabled..."
fi
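# Example (illustrative; script name assumed): ./tests.sh flash_off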
############### ###############
arr_prefix=() arr_prefix=()
@ -99,6 +107,8 @@ if [ "$RUN_BIG_TESTS" = true ]; then
add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M" add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
# add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra # add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
# add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M" # not always working # add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M" # not always working
add_test_vision "sabafallah/DeepSeek-OCR-GGUF:q8_0" -p "Free OCR." --chat-template deepseek-ocr
add_test_vision "ggml-org/GLM-4.6V-Flash-GGUF:Q4_K_M" -p "extract all texts from this image"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M" add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M" add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
@ -142,6 +152,7 @@ for i in "${!arr_hf[@]}"; do
-hf $(printf %q "$hf") \ -hf $(printf %q "$hf") \
--image $(printf %q "$SCRIPT_DIR/$inp_file") \ --image $(printf %q "$SCRIPT_DIR/$inp_file") \
--temp 0 -n 128 \ --temp 0 -n 128 \
--flash-attn $(printf %q "$FLASH_ATTN") \
${extra_args}" ${extra_args}"
# if extra_args does not contain -p, we add a default prompt # if extra_args does not contain -p, we add a default prompt

View File

@ -0,0 +1,85 @@
<|ref|>title<|/ref|><|det|>[[61, 255, 907, 533]]<|/det|>
# MEN WALK ON MOON
ASTRONAUTS LAND ON PLAIN;
COLLECT ROCKS, PLANT FLAG
<|ref|>text<|/ref|><|det|>[[56, 559, 268, 629]]<|/det|>
Voice From Moon:
Eagle Has Landed'
<|ref|>text<|/ref|><|det|>[[74, 645, 262, 675]]<|/det|>
EAGLE (the lunar surface, Houston, Truesquily)
Base here, The Eagle has landed.
<|ref|>text<|/ref|><|det|>[[74, 675, 262, 720]]<|/det|>
BOOTHROOM: Lounge, Truesquily, we enjoy you on the ground. You've got a bunch of guys about to toss bikes. We're breaking again. Thanks a lot.
<|ref|>text<|/ref|><|det|>[[74, 720, 262, 750]]<|/det|>
TRAVELLING MADE: Time you. BOOTHROOM: You're looking good here.
<|ref|>text<|/ref|><|det|>[[74, 750, 262, 780]]<|/det|>
TRAVELLING MADE: A very smooth touchdown. BEDROOM: Eagle, you are very far. I'll. (The first sign in the lunar appearance) (Over.)
<|ref|>text<|/ref|><|det|>[[74, 780, 262, 810]]<|/det|>
TRAVELLING MADE: Eagle, stay for I'll. BOOTHROOM: Bumper and we are you waiting the cue.
<|ref|>text<|/ref|><|det|>[[74, 810, 262, 830]]<|/det|>
TRAVELLING MADE: Eagle, and service mobility.
<|ref|>text<|/ref|><|det|>[[74, 830, 262, 850]]<|/det|>
How do you read me?
<|ref|>text<|/ref|><|det|>[[74, 850, 262, 880]]<|/det|>
TRAVELLING COLUMBIA, he has landed Truesquily. Base, Eagle is at Truesquily. I read you first by. Over.
<|ref|>text<|/ref|><|det|>[[74, 880, 262, 900]]<|/det|>
COLUMBIA: Yes, I heard the whole thing.
<|ref|>text<|/ref|><|det|>[[74, 900, 262, 920]]<|/det|>
BOOTHROOM: Well, it's a good show.
<|ref|>text<|/ref|><|det|>[[74, 920, 262, 940]]<|/det|>
COLUMBIA: Fantastic.
<|ref|>text<|/ref|><|det|>[[74, 940, 262, 960]]<|/det|>
TRAVELLING MADE: I'll read that.
<|ref|>text<|/ref|><|det|>[[74, 960, 262, 980]]<|/det|>
APOLLO CONTROL: The most major sky to sky will be for the 23 event, that is at 21 minutes 26 sec-
<|ref|>text<|/ref|><|det|>[[74, 980, 262, 990]]<|/det|>
tion of lunar descent.
<|ref|>image<|/ref|><|det|>[[270, 545, 697, 990]]<|/det|>
<|ref|>text<|/ref|><|det|>[[715, 559, 911, 629]]<|/det|>
A Powdery Surface
Is Closely Explored
<|ref|>text<|/ref|><|det|>[[733, 645, 851, 665]]<|/det|>
BY JOHN NOBLE WILFORD
<|ref|>text<|/ref|><|det|>[[715, 669, 911, 700]]<|/det|>
HOUSTON, Monday, July 21—New hires landed and walked on the moon.
<|ref|>text<|/ref|><|det|>[[715, 700, 911, 750]]<|/det|>
Two Americans, astronauts of Apollo 11, steered their Eagle-shaped lunar module safely and smoothly to the lunar landing yesterday at 4:17:40 P.M., Eastern day-light time.
<|ref|>text<|/ref|><|det|>[[715, 750, 911, 780]]<|/det|>
Neil A. Armstrong, the 38-year-old civilian commander, radioed to earth and the landing team here.
<|ref|>text<|/ref|><|det|>[[715, 780, 911, 830]]<|/det|>
"Boom, Truesquily! Base here. The Eagle has landed," the first man to reach the moon—Neil Armstrong and his engineer, Capt. Charles E. Alder, of the Jet Propulsion Laboratory, the space agency's rocket and space program manager.
<|ref|>text<|/ref|><|det|>[[715, 830, 911, 880]]<|/det|>
About six and a half hours later, Mr. Armstrong opened the landing craft's hatch, stepped slowly down the ladder and descended as he pointed his first landing footguard on the lunar crater.
<|ref|>text<|/ref|><|det|>[[715, 880, 911, 920]]<|/det|>
"That's one small step for man, one giant leap for mankind."
<|ref|>text<|/ref|><|det|>[[715, 920, 911, 960]]<|/det|>
His first step on the moon came on 10:56:29 P.M., as a television camera recorded the craft's transmitted his every word to an aerial and excited audiences of hundreds of millions of people on earth.
<|ref|>text<|/ref|><|det|>[[749, 960, 861, 974]]<|/det|>
Testable Slope Test Soil

View File

@ -0,0 +1,42 @@
MEN WALK ON MOON
ASTRONAUTS LAND ON PLAIN;
COLLECT ROCKS, PLANT FLAG
Voice From Moon:
'Eagle Has Landed'
A Powder Surface
Is Closely Explored
By JOHN NOBLE WILFORD
NOVEMBER, Monday, July 21—New York Herald and
wished on the moon.
Two American astronauts of Apollo 11, steered their
frigate Eagle toward the moon's surface and smoothly to
the lunar landing yesterday at 4:17:40 P.M., Eastern day-
light time.
Neil A. Armstrong, the 38-year-old civilian commander,
landed on the soft sand of the moon's surface here.
"Beautiful, Triumph!" he said. "The Eagle has landed."
The first man to reach the moon—Neil Armstrong and
his co-pilot, Charles E. "Pete" Conrad, 26, of the Pentagon,
brought their ship to rest on a level, rock-strewn plain near
the moon's surface. The two men and two of the three
astronauts on board, Armstrong, Conrad and Edwin E.
Aldrin, 38, of Houston, stepped slowly down the ladder
and descended as he pointed his first full-flaming footpad
at the lunar crater.
"That's one small step for man, one giant leap for
mankind."
His first step on the moon came at 10:56:20 P.M., as
a television camera rolled the earth's thousandth line every
second to an aerial and studied audiences of hundreds of
millions of people on earth.
Textile Slope Test Soil

View File

@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""
Test script to compare llama.cpp mtmd-cli output with HuggingFace reference implementation
for DeepSeek-OCR model using embedding similarity.
"""
import argparse
import subprocess
import sys
from pathlib import Path
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
def run_mtmd_deepseek_ocr(
model_path: str,
mmproj_path: str,
image_path: str,
bin_path: str,
prompt: str = "Free OCR."
) -> str:
"""
Run inference using llama.cpp mtmd-cli.
"""
cmd = [
bin_path,
"-m", model_path,
"--mmproj", mmproj_path,
"--image", image_path,
# "-p", "<|grounding|>Convert the document to markdown.",
"-p", prompt,
"--chat-template", "deepseek-ocr",
"--temp", "0",
"-n", "1024",
# "--verbose"
]
print(f"Running llama.cpp command: {' '.join(cmd)}")
result = subprocess.run(
cmd,
capture_output=True,
text=False,
timeout=300
)
if result.returncode != 0:
stderr = result.stderr.decode('utf-8', errors='replace')
print(f"llama.cpp stderr: {stderr}")
raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}")
output = result.stdout.decode('utf-8', errors='replace').strip()
print(f"llama.cpp output length: {len(output)} chars")
return output
def compute_embedding_similarity(text1: str, text2: str, model_name: str) -> float:
"""
Compute cosine similarity between two texts using embedding model.
"""
print(f"Loading embedding model: {model_name}")
# Use sentence-transformers for easier embedding extraction
embed_model = SentenceTransformer(model_name)
print("Computing embeddings...")
embeddings = embed_model.encode([text1, text2], convert_to_numpy=True)
similarity = util.cos_sim(embeddings[0], embeddings[1])[0][0]
return float(similarity)
def read_expected_output(file_path: str) -> str:
"""
Read expected OCR output from file.
"""
cur_path = Path(__file__).parent
expected_path = str(cur_path / file_path)
with open(expected_path, "r", encoding="utf-8") as f:
return f.read().strip()
def main():
ap = argparse.ArgumentParser(description="Compare llama.cpp and HuggingFace DeepSeek-OCR outputs")
ap.add_argument("--llama-model", default="gguf_models/deepseek-ai/deepseek-ocr-f16.gguf",
help="Path to llama.cpp GGUF model")
ap.add_argument("--mmproj", default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-f16.gguf",
help="Path to mmproj GGUF file")
ap.add_argument("--image", default="test-1.jpeg",
help="Path to test image")
ap.add_argument("--llama-bin", default="build/bin/llama-mtmd-cli",
help="Path to llama-mtmd-cli binary")
ap.add_argument("--embedding-model", default="Qwen/Qwen3-Embedding-0.6B",
help="Embedding model for similarity computation")
ap.add_argument("--threshold", type=float, default=0.7,
help="Minimum similarity threshold for pass")
args = ap.parse_args()
# Validate paths
# script directory + image
mtmd_dir = Path(__file__).parent.parent
args.image = str(mtmd_dir / args.image)
# project directory + llama model
args.llama_model = str(mtmd_dir.parent.parent / args.llama_model)
# project directory + mmproj
args.mmproj = str(mtmd_dir.parent.parent / args.mmproj)
args.llama_bin = str(mtmd_dir.parent.parent / args.llama_bin)
if not Path(args.image).exists():
print(f"Error: Image not found: {args.image}")
sys.exit(1)
if not Path(args.llama_model).exists():
print(f"Error: Model not found: {args.llama_model}")
sys.exit(1)
if not Path(args.mmproj).exists():
print(f"Error: mmproj not found: {args.mmproj}")
sys.exit(1)
print("=" * 60)
print("DeepSeek-OCR: llama.cpp vs HuggingFace Comparison")
print("=" * 60)
# Run llama.cpp inference
print("\n[1/2] Running llama.cpp implementation...")
llama_free_ocr = run_mtmd_deepseek_ocr(
args.llama_model,
args.mmproj,
args.image,
args.llama_bin
)
llama_md_ocr = run_mtmd_deepseek_ocr(
args.llama_model,
args.mmproj,
args.image,
args.llama_bin,
prompt="<|grounding|>Convert the document to markdown."
)
expected_free_ocr = read_expected_output("test-1-extracted.txt")
expected_md_ocr = read_expected_output("test-1-extracted.md")
# Compute similarity
print("\n[3/3] Computing embedding similarity...")
free_ocr_similarity = compute_embedding_similarity(
expected_free_ocr,
llama_free_ocr,
args.embedding_model
)
md_ocr_similarity = compute_embedding_similarity(
expected_md_ocr,
llama_md_ocr,
args.embedding_model
)
# Results
print("\n" + "=" * 60)
print("RESULTS")
print("=" * 60)
print(f"\nReference Model output:\n{'-' * 40}")
print(expected_free_ocr)
print(f"\nDeepSeek-OCR output:\n{'-' * 40}")
print(llama_free_ocr)
print(f"\n{'=' * 60}")
print(f"Cosine Similarity: {free_ocr_similarity:.4f}")
print(f"Threshold: {args.threshold}")
print(f"Result: {'PASS' if free_ocr_similarity >= args.threshold else 'FAIL'}")
print("=" * 60)
# Markdown OCR results
print(f"\nReference Model Markdown output:\n{'-' * 40}")
print(expected_md_ocr)
print(f"\nDeepSeek-OCR Markdown output:\n{'-' * 40}")
print(llama_md_ocr)
print(f"\n{'=' * 60}")
print(f"Cosine Similarity (Markdown): {md_ocr_similarity:.4f}")
print(f"Threshold: {args.threshold}")
print(f"Result: {'PASS' if md_ocr_similarity >= args.threshold else 'FAIL'}")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,5 @@
sentence-transformers
transformers
tokenizers
torch
torchvision