diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 222f6ed6dc..82a6c95bdd 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -620,6 +620,9 @@ class ModelBase:
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
+        if "language_config" in config:
+            # rename for DeepSeekOCR
+            config["text_config"] = config["language_config"]
         return config
 
     @classmethod
@@ -1442,7 +1445,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "width.clip-l-14-224.layers", "sam_vit_b.layers"]
 
     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -1488,13 +1491,31 @@ class MmprojModel(ModelBase):
             # TODO @ngxson : this is a hack to support both vision and audio encoders
             have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
             self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
+            # FIXME: DeepseekOCRVisionModel specific hack
+            if self.block_count is None:
+                if isinstance(self, DeepseekOCRVisionModel):
+                    clip_block_count = self.hparams['width']['clip-l-14-224']['layers']
+                    sam_block_count = self.hparams['width']['sam_vit_b']['layers']
+                    if clip_block_count is not None:
+                        self.block_count = clip_block_count
+                    if sam_block_count is not None:
+                        self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count
+                if self.block_count is None:
+                    raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
             self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
 
             # load preprocessor config
             self.preprocessor_config = {}
             if not self.is_mistral_format:
-                with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
-                    self.preprocessor_config = json.load(f)
+                # prefer preprocessor_config.json if it exists
+                if (self.dir_model / "preprocessor_config.json").is_file():
+                    with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+                        self.preprocessor_config = json.load(f)
+                # otherwise fall back to processing_config.json if present
+                elif (self.dir_model / "processing_config.json").is_file():
+                    with open(self.dir_model / "processing_config.json", "r", encoding="utf-8") as f:
+                        self.preprocessor_config = json.load(f)
 
     def get_vision_config(self) -> dict[str, Any] | None:
         config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
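For context, a small illustrative Python sketch of the config shape that the block-count hack above expects for DeepSeek-OCR. The nested `width` keys come from the lookup in the code; the concrete layer counts (24 for CLIP-L/14, 12 for SAM ViT-B) are assumptions used only for this example:

```python
# Illustrative only: a hparams dict shaped the way the DeepseekOCR block-count hack expects.
# The layer counts below are assumptions (typical CLIP-L/14 and SAM ViT-B depths), not values
# taken from the patch.
hparams = {
    "width": {
        "clip-l-14-224": {"layers": 24},
        "sam_vit_b": {"layers": 12},
    },
}

clip_block_count = hparams["width"]["clip-l-14-224"]["layers"]
sam_block_count = hparams["width"]["sam_vit_b"]["layers"]

# mirrors the fallback above: take whichever count is present, sum them when both are,
# so the MMPROJ tensor name map is sized for the CLIP and SAM blocks together
block_count = (clip_block_count or 0) + (sam_block_count or 0)
assert block_count == 36
```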
@@ -5770,6 +5791,61 @@ class Gemma3VisionModel(MmprojModel):
 
         return []  # skip other tensors
 
+
+@ModelBase.register("DeepseekOCRForCausalLM")
+class DeepseekOCRVisionModel(MmprojModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR)
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
+        # calculate proj_scale_factor
+        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
+        n_per_side = int(image_seq_length ** 0.5)
+        image_size = self.hparams["image_size"]
+        patch_size = self.hparams["patch_size"]
+        proj_scale_factor = (image_size // patch_size) // n_per_side
+        if proj_scale_factor > 0 and proj_scale_factor != 4:
+            # only write this key when it differs from the default value
+            self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        vision_config = self.global_config.get("vision_config")
+        return vision_config if vision_config is not None else super().get_vision_config()
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        # related to https://github.com/ggml-org/llama.cpp/issues/13025
+        if "input_projection" in name:
+            return gguf.GGMLQuantizationType.F16
+        if ".embeddings." in name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if "vision_model.head." in name:
+            return []  # skip redundant head tensors
+
+        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
+            # process vision tensors
+            name = name.replace("_weight", ".weight")
+
+            # correct norm value; only "soft_emb_norm" needs to be corrected as it is part of the Gemma-style projector
+            # the other norm values are part of the SigLIP model and are already correct
+            # ref code: Gemma3RMSNorm
+            if "soft_emb_norm.weight" in name:
+                logger.info(f"Correcting norm value for '{name}'")
+                data_torch = data_torch + 1
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []  # skip other tensors
+
 
 @ModelBase.register("Gemma3nForConditionalGeneration")
 class Gemma3NModel(Gemma3Model):
@@ -6943,6 +7019,7 @@ class DeepseekModel(TextModel):
 
 @ModelBase.register(
     "DeepseekV2ForCausalLM",
     "DeepseekV3ForCausalLM",
+    "DeepseekOCRForCausalLM",
     "KimiVLForConditionalGeneration",
 )
 class DeepseekV2Model(TextModel):
@@ -7009,31 +7086,35 @@ class DeepseekV2Model(TextModel):
         super().set_gguf_parameters()
         hparams = self.hparams
 
+        kv_lora_rank = hparams["kv_lora_rank"] if hparams.get("kv_lora_rank") is not None else 512
+        routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
+        norm_topk_prob = hparams.get("norm_topk_prob", False)
+        scoring_func = hparams.get("scoring_func", "softmax")
         self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
-        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
+        self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
 
         # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(kv_lora_rank)
         self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
         self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
-        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
-        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) - if hparams["scoring_func"] == "sigmoid": + if scoring_func == "sigmoid": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - elif hparams["scoring_func"] == "softmax": + elif scoring_func == "softmax": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) else: - raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + raise ValueError(f"Unsupported scoring_func value: {scoring_func}") self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) @@ -7043,12 +7124,14 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # skip vision tensors and remove "language_model." for Kimi-VL - if "vision_tower" in name or "multi_modal_projector" in name: + if "vision_" in name or "multi_modal_projector" in name \ + or "image_newline" in name or "model.projector" in name or "sam_model" in name or "view_seperator" in name: return [] if name.startswith("language_model."): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6b4b6c5ab0..dfd947083a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -664,6 +664,21 @@ class MODEL_TENSOR(IntEnum): V_MM_GATE = auto() # cogvlm V_TOK_BOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm + # DeepSeek-OCR sam_model + V_SAM_POS_EMBD = auto() + V_SAM_PATCH_EMBD = auto() + V_SAM_PRE_NORM = auto() + V_SAM_POST_NORM = auto() + V_SAM_ATTN_POS_H = auto() + V_SAM_ATTN_POS_W = auto() + V_SAM_ATTN_QKV = auto() + V_SAM_ATTN_OUT = auto() + V_SAM_MLP_LIN_1 = auto() + V_SAM_MLP_LIN_2 = auto() + V_SAM_NECK = auto() + V_SAM_NET_2 = auto() + V_SAM_NET_3 = auto() + # audio (mtmd) A_ENC_EMBD_POS = auto() A_ENC_CONV1D = auto() @@ -1030,6 +1045,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_MM_GATE: "mm.gate", MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_EOI: "v.eoi", + # DeepSeek-OCR sam_model + MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd", + MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd", + MODEL_TENSOR.V_SAM_PRE_NORM: "v.sam.blk.{bid}.pre_ln", + MODEL_TENSOR.V_SAM_POST_NORM: "v.sam.blk.{bid}.post_ln", + MODEL_TENSOR.V_SAM_ATTN_POS_H: "v.sam.blk.{bid}.attn.pos_h", + MODEL_TENSOR.V_SAM_ATTN_POS_W: "v.sam.blk.{bid}.attn.pos_w", + MODEL_TENSOR.V_SAM_ATTN_QKV: "v.sam.blk.{bid}.attn.qkv", + MODEL_TENSOR.V_SAM_ATTN_OUT: "v.sam.blk.{bid}.attn.out", + MODEL_TENSOR.V_SAM_MLP_LIN_1: "v.sam.blk.{bid}.mlp.lin1", + MODEL_TENSOR.V_SAM_MLP_LIN_2: "v.sam.blk.{bid}.mlp.lin2", + MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}", + MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2", + MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3", # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", @@ -2247,7 +2276,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_V_B, 
        MODEL_TENSOR.ATTN_Q_A_NORM,
        MODEL_TENSOR.ATTN_KV_A_NORM,
@@ -3207,6 +3238,7 @@ class VisionProjectorType:
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"
+    DEEPSEEKOCR = "deepseekocr"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 9294066876..f15ea0a02a 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1457,6 +1459,58 @@ class TensorNameMap:
             "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
         ),
 
+        MODEL_TENSOR.V_SAM_POS_EMBD: (
+            "model.sam_model.pos_embed",
+        ),
+
+        MODEL_TENSOR.V_SAM_PATCH_EMBD: (
+            "model.sam_model.patch_embed.proj",
+        ),
+
+        MODEL_TENSOR.V_SAM_PRE_NORM: (
+            "model.sam_model.blocks.{bid}.norm1",
+        ),
+
+        MODEL_TENSOR.V_SAM_POST_NORM: (
+            "model.sam_model.blocks.{bid}.norm2",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_POS_H: (
+            "model.sam_model.blocks.{bid}.attn.rel_pos_h",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_POS_W: (
+            "model.sam_model.blocks.{bid}.attn.rel_pos_w",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_QKV: (
+            "model.sam_model.blocks.{bid}.attn.qkv",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_OUT: (
+            "model.sam_model.blocks.{bid}.attn.proj",
+        ),
+
+        MODEL_TENSOR.V_SAM_MLP_LIN_1: (
+            "model.sam_model.blocks.{bid}.mlp.lin1",
+        ),
+
+        MODEL_TENSOR.V_SAM_MLP_LIN_2: (
+            "model.sam_model.blocks.{bid}.mlp.lin2",
+        ),
+
+        MODEL_TENSOR.V_SAM_NECK: (
+            "model.sam_model.neck.{bid}",
+        ),
+
+        MODEL_TENSOR.V_SAM_NET_2: (
+            "model.sam_model.net_2",
+        ),
+
+        MODEL_TENSOR.V_SAM_NET_3: (
+            "model.sam_model.net_3",
+        ),
+
         MODEL_TENSOR.V_MM_POST_FC_NORM: (
             "model.vision.linear_proj.norm1", # cogvlm
         ),
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 722b1a4948..8d1c7d0dff 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -129,6 +129,24 @@
 #define TN_TOK_BOI          "v.boi"
 #define TN_TOK_EOI          "v.eoi"
 
+// deepseek-ocr
+#define TN_SAM_POS_EMBD     "sam.pos_embd"
+#define TN_SAM_PATCH_EMBD   "sam.patch_embd"
+#define TN_SAM_PRE_NORM     "sam.blk.%d.pre_ln"
+#define TN_SAM_POST_NORM    "sam.blk.%d.post_ln"
+#define TN_SAM_ATTN_POS_H   "sam.blk.%d.attn.pos_h"
+#define TN_SAM_ATTN_POS_W   "sam.blk.%d.attn.pos_w"
+#define TN_SAM_ATTN_QKV     "sam.blk.%d.attn.qkv"
+#define TN_SAM_ATTN_OUT     "sam.blk.%d.attn.out"
+#define TN_SAM_MLP_LIN_1    "sam.blk.%d.mlp.lin1"
+#define TN_SAM_MLP_LIN_2    "sam.blk.%d.mlp.lin2"
+#define TN_SAM_NECK         "sam.neck.%d"
+#define TN_SAM_NET_2        "sam.net_2"
+#define TN_SAM_NET_3        "sam.net_3"
+
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
 
@@ -156,6 +174,7 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_DEEPSEEK_OCR,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -182,6 +201,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_DEEPSEEK_OCR,"deepseekocr"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index d1423b67f9..0961b96fd6 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -222,6 +222,33 @@ struct clip_hparams {
         warmup_image_size = n_tok_per_side * patch_size * cur_merge; // TODO: support warmup size for custom token numbers
     }
+
+    // sam vit deepseek-ocr
+    std::vector<int32_t> global_attn_indices() const {
+        switch (n_embd) {
+            case  768: return { 2, 5, 8, 11 };
+            case 1024: return { 5, 11, 17, 23 };
+            case 1280: return { 7, 15, 23, 31 };
+            default:
+                {
+                    fprintf(stderr, "%s: unsupported SAM n_embd = %d\n", __func__, n_embd);
+                } break;
+        };
+
+        return {};
+    }
+
+    bool is_global_attn(int32_t layer) const {
+        const auto indices = global_attn_indices();
+
+        for (const auto & idx : indices) {
+            if (layer == idx) {
+                return true;
+            }
+        }
+
+        return false;
+    }
 };
 
 struct clip_layer {
@@ -271,6 +298,10 @@ struct clip_layer {
     bool has_deepstack() const {
         return deepstack_fc1_w != nullptr;
     }
+
+    // sam rel_pos
+    ggml_tensor * rel_pos_w = nullptr;
+    ggml_tensor * rel_pos_h = nullptr;
 };
 
 struct clip_model {
@@ -308,6 +339,7 @@ struct clip_model {
     ggml_tensor * mm_2_b = nullptr;
 
     ggml_tensor * image_newline = nullptr;
+    ggml_tensor * view_seperator = nullptr;
 
     // Yi type models with mlp+normalization projection
     ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
@@ -400,6 +432,11 @@ struct clip_model {
     ggml_tensor * mm_boi = nullptr;
     ggml_tensor * mm_eoi = nullptr;
 
+    // deepseek ocr sam
+    ggml_tensor * patch_embed_proj_w = nullptr;
+    ggml_tensor * patch_embed_proj_b = nullptr;
+    ggml_tensor * pos_embed = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -409,6 +446,15 @@ struct clip_model {
         return proj_type == PROJECTOR_TYPE_ULTRAVOX
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
     }
+
+    ggml_tensor * neck_conv_0   = nullptr;
+    ggml_tensor * neck_norm_0_w = nullptr;
+    ggml_tensor * neck_norm_0_b = nullptr;
+    ggml_tensor * neck_conv_1   = nullptr;
+    ggml_tensor * neck_norm_1_w = nullptr;
+    ggml_tensor * neck_norm_1_b = nullptr;
+
+    std::vector<clip_layer> enc_layers;
+
 };
 
 struct clip_ctx {
@@ -521,9 +567,9 @@ struct clip_graph {
         hparams(model.hparams),
         img(img),
         patch_size(hparams.patch_size),
-        n_patches_x(img.nx / patch_size),
-        n_patches_y(img.ny / patch_size),
-        n_patches(n_patches_x * n_patches_y),
+        n_patches_x(img.nx / patch_size),     // sam 1024 / 16 = 64
+        n_patches_y(img.ny / patch_size),     // sam 1024 / 16 = 64
+        n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096
         n_embd(hparams.n_embd),
         n_head(hparams.n_head),
         d_head(n_embd / n_head),
@@ -619,6 +665,244 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_tensor * build_sam_enc(ggml_tensor * inp_raw,
+                                const int enc_image_size = 1024) {
+        constexpr int enc_n_embd     = 768;
+        constexpr int _depth         = 12;
+        constexpr int enc_n_heads    = 12;
+        constexpr int enc_d_heads    = enc_n_embd / enc_n_heads;
+        constexpr int _prompt_n_embd = 256;
+        constexpr int enc_patch_size = 16;
+        constexpr int _window_size   = 14;
+
+        const int enc_n_patches = enc_image_size / enc_patch_size; // 64 patches per side
+
+        // build_enc_inp() expects the total number of patches (64 * 64 = 4096)
+        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches * enc_n_patches, enc_n_embd);
+        ggml_tensor * cur  = ggml_add(ctx0, inpL, model.position_embeddings);
+
+        // keep the running layer input around for the residual connections below
+        inpL = cur;
+
+        // loop over layers
+        for (int il = 0; il < _depth; il++) {
+            auto & layer = model.enc_layers[il];
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "enc_layer_inp_normed", il);
+
+            const int64_t w0 = cur->ne[1];
+            const int64_t h0 = cur->ne[2];
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - apply window partition
+                // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
+                cur = ggml_win_part(ctx0, cur, 14);
+            }
+
+            const int64_t W = cur->ne[1];
+            const int64_t H = cur->ne[2];
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = ggml_add(ctx0, cur, layer.qkv_b);
+
+                const int B = cur->ne[3];
+
+                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B);
+                cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));
+
+                ggml_tensor * Qcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0);
+                Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B);
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+                Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads);
+
+                ggml_tensor * Kcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 1 * cur->nb[3]);
+                Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B);
+                Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads);
+
+                ggml_tensor * Vcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]);
+                Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B);
+                Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed
+                Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * enc_n_heads);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur);
+
+                // scale by 1/sqrt(head_dim)
+                struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));
+
+                struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W);
+                struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H);
+
+                struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
+
+                struct ggml_tensor * rel_w = ggml_cont(
+                    ctx0,
+                    ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0,
+                                 2, 1, 3));
+                struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
+
+                struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
+
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max);
+
+                cur = ggml_reshape_4d(
+                    ctx0,
+                    ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B),
+                                                 0, 2, 1, 3)),
+                    enc_n_embd, W, H, B);
+
+                cur = ggml_mul_mat(ctx0, layer.o_w, cur);
+                cur = ggml_add_inplace(ctx0, cur, layer.o_b);
+            }
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - reverse window partition
+                cur = ggml_win_unpart(ctx0, cur, w0, h0, 14);
+            }
+
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
+            // re-add the layer input, i.e. the first residual connection
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cb(cur, "ffn_inp", il);
+
+            // keep the ffn input for the second residual connection
+            ggml_tensor * inpFF = cur;
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b, layer.ff_down_w,
+                            layer.ff_down_b, hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
+            // residual 2
+            cur = ggml_add(ctx0, inpFF, cur);
+            cb(cur, "layer_out", il);
+
+            // feed the next layer (also consumed by the neck after the last layer)
+            inpL = cur;
+        }
+
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
+
+        cur =
ggml_conv_2d_sk_p0(ctx0, model.neck_conv_0, cur); + + cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_0_w, model.neck_norm_0_b, hparams.eps); + + cur = ggml_conv_2d_s1_ph(ctx0, model.neck_conv_1, cur); + + cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_1_w, model.neck_norm_1_b, hparams.eps); + + //cur = ggml_cpy(ctx0, cur, state.embd_img); + + ggml_build_forward_expand(gf, cur); + return cur; + } + + ggml_tensor * sam_layer_norm_2d(ggml_context * ctx0, + ggml_tensor * layer, + int n_channels, + ggml_tensor * w, + ggml_tensor * b, + float eps) { + // LayerNorm2d + // normalize along channel dimmension + // TODO: better implementation + layer = ggml_permute(ctx0, ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), 2, 0, + 1, 3); + + layer = + ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), layer), + ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer)); + + return layer; + } + + ggml_cgraph * build_deepseek_ocr() { + //patch embedding + ggml_tensor * inp_raw = build_inp_raw(); + + + ggml_tensor * global_features_1 = build_sam_enc(inp_raw); + + ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); + + // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) + ggml_tensor * global_features = ggml_concat(ctx0, global_features_1, global_features_2, 0); + global_features = build_global_local_features( + ctx0, + global_features, + n_patches_y, + n_patches_x, + n_embd + ); + + return gf; + } + + // global_features: [n_dim, h*w] + // image_newline: [n_dim] + // view_separator: [n_dim] + + ggml_tensor * build_global_local_features(ggml_context * ctx0, + ggml_tensor * global_features, + int h, + int w, + int n_dim) { + GGML_ASSERT(model.image_newline != nullptr); + GGML_ASSERT(model.view_seperator != nullptr); + GGML_ASSERT(global_features->ne[0] == (int64_t) n_dim); + GGML_ASSERT(global_features->ne[1] == (int64_t) (h * w)); + + // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] + ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h) + t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim) + + // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] + ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim) + + ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim) + nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim) + nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim) + + // 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim) + t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) + + // 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1)) + t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h) + t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1)) + + // 5) append view_separator as an extra "token": + // view_separator: [n_dim] -> [n_dim, 1] + ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) + + // concat along token dimension (dim=1): + ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) + + return global_local_features; + } + + + ggml_cgraph * build_pixtral() { const int n_merge = hparams.n_merge; @@ -1215,7 +1499,7 @@ struct clip_graph { norm_t, hparams.ffn_op, 
model.position_embeddings, - nullptr); + nullptr); // shape [1024, 16, 16] // remove CLS token cur = ggml_view_2d(ctx0, cur, @@ -1261,6 +1545,65 @@ struct clip_graph { return gf; } + ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + auto n_embd_vit_clip = 1024; + + const int n_pos = n_patches + 1; + ggml_tensor * inp = + ggml_cont_3d(ctx0, ggml_dup_tensor(ctx0, patch_embeds), patch_embeds->ne[0], n_patches_x, n_patches_y); + //ggml_tensor * inp = ggml_cpy(ctx0, inpL, ggml_dup_tensor(ctx0, inpL)); + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + // The larger models use a different ViT, which uses RMS norm instead of layer norm + // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 + norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) ? + NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B) + : + NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models) + + ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, model.position_embeddings, + nullptr); // shape [1024, 16, 16] + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, n_embd, n_patches, ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + { + const int scale_factor = model.hparams.n_merge; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = n_patches_y; + const int width = n_patches_x; + GGML_ASSERT(scale_factor > 0); + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, + width / scale_factor, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]); + } + + // projector (always using GELU activation) + { + // projector LayerNorm uses pytorch's default eps = 1e-5 + // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 + cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); + cur = ggml_add(ctx0, cur, model.mm_3_b); + } + + // build the graph + + return cur; + } + ggml_cgraph * build_llama4() { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); @@ -2164,18 +2507,41 @@ private: return inpL; } + // build the input after conv2d (inp_raw --> patches) + // returns tensor with shape [n_embd, n_patches] + ggml_tensor * build_enc_inp(ggml_tensor * inp_raw, + const int enc_patch_size, + const int enc_n_patches, + const int enc_n_embd) { + GGML_ASSERT(model.patch_embed_proj_w != nullptr); + GGML_ASSERT(model.patch_embed_proj_b != nullptr); + // Image to Patch Embedding. 
+        // ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3]
+        // patch_embed_proj_w shape = [768, 3, 16, 16]
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0,
+                                         1, 1); // [64, 64, 768]
+        inp = ggml_reshape_2d(ctx0, inp, enc_n_patches, enc_n_embd); // [4096, 768]
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));            // [768, 4096]
+        inp = ggml_add(ctx0, inp, model.patch_embed_proj_b);
+        cb(inp, "enc_patch_bias", -1);
+        return inp;
+    }
+
     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]
     ggml_tensor * build_inp() {
-        ggml_tensor * inp_raw = build_inp_raw();
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
-        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        // Image to Patch Embedding.
+        ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3]
+        // sam patch_embeddings_0 shape = [768, 3, 16, 16]
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768]
+        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768]
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));    // sam shape = [768, 4096]
         if (model.patch_bias) {
+            // sam patch_bias shape = [768]
             inp = ggml_add(ctx0, inp, model.patch_bias);
             cb(inp, "patch_bias", -1);
         }
-        return inp;
+        return inp; // shape = [n_embd, n_patches] same as [768, 4096]
     }
 
     ggml_tensor * build_inp_raw(int channels = 3) {
@@ -3236,6 +3602,10 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
                 } break;
+            case PROJECTOR_TYPE_DEEPSEEK_OCR:
+                {
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -4192,6 +4562,59 @@ private:
     }
 };
 
+static std::vector<std::pair<int, int>> ds_build_target_ratios(const int min_num, const int max_num) {
+    std::vector<std::pair<int, int>> ratios;
+    for (int n = min_num; n <= max_num; ++n) {
+        for (int i = 1; i <= n; ++i) {
+            for (int j = 1; j <= n; ++j) {
+                if (const int blocks = i * j; blocks >= min_num && blocks <= max_num) {
+                    ratios.emplace_back(i, j); // (cols, rows)
+                }
+            }
+        }
+    }
+
+    // sort by total blocks like in Python (key=lambda x: x[0] * x[1])
+    std::sort(ratios.begin(), ratios.end(),
+              [](const auto & a, const auto & b) {
+                  return (a.first * a.second) < (b.first * b.second);
+              });
+
+    // optional: dedup
+    ratios.erase(std::unique(ratios.begin(), ratios.end()), ratios.end());
+    return ratios;
+}
+
+static std::pair<int, int> ds_find_closest_aspect_ratio(
+        const float aspect_ratio,
+        const std::vector<std::pair<int, int>> & target_ratios,
+        const int width,
+        const int height,
+        const int image_size) {
+    float best_diff = std::numeric_limits<float>::infinity();
+    std::pair<int, int> best_ratio = {1, 1};
+    const float area = static_cast<float>(width) * static_cast<float>(height);
+
+    for (const auto & r : target_ratios) {
+        const float target_ar = static_cast<float>(r.first) / static_cast<float>(r.second);
+
+        if (const float diff = std::fabs(aspect_ratio - target_ar); diff < best_diff) {
+            best_diff  = diff;
+            best_ratio = r;
+        } else if (diff == best_diff) {
+            // same as python: prefer this ratio if the image area is "large enough"
+            if (const float needed_area = 0.5f * image_size * image_size * r.first * r.second; area > needed_area) {
+                best_ratio = r;
+            }
+        }
+    }
+
+    return best_ratio; // (cols, rows)
+}
+
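For reference, a minimal standalone Python sketch of the grid-selection logic above (the comments in the patch note that these helpers port the original DeepSeek-OCR Python preprocessing). The function names and the demo values here are illustrative, not part of the patch: a 1280x800 page (aspect ratio 1.6) gets a 3x2 grid of 640x640 tiles, i.e. a 1920x1280 refined canvas.

```python
# Illustrative sketch mirroring ds_build_target_ratios / ds_find_closest_aspect_ratio.
import math

def build_target_ratios(min_num: int = 2, max_num: int = 9) -> list[tuple[int, int]]:
    # all (cols, rows) grids whose total tile count stays within [min_num, max_num]
    ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    return sorted(ratios, key=lambda r: r[0] * r[1])

def find_closest_aspect_ratio(width: int, height: int, image_size: int = 640) -> tuple[int, int]:
    aspect_ratio = width / height
    best, best_diff = (1, 1), math.inf
    for cols, rows in build_target_ratios():
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best, best_diff = (cols, rows), diff
        elif diff == best_diff and width * height > 0.5 * image_size * image_size * cols * rows:
            # tie-break: prefer the larger grid when the image area is large enough
            best = (cols, rows)
    return best  # (cols, rows)

if __name__ == "__main__":
    cols, rows = find_closest_aspect_ratio(1280, 800)  # aspect 1.6 -> 3 x 2 grid
    print(cols, rows, cols * 640, rows * 640)          # 3 2 1920 1280
```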
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
@@ -4406,6 +4829,69 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                     }
                 }
             } break;
+        case PROJECTOR_TYPE_DEEPSEEK_OCR:
+            {
+                // configurable, or read from params
+                const int min_num = 2;
+                const int max_num = 9;
+                const int image_size = params.image_size; // typically 640
+                const bool use_thumbnail = true; // mimic python's use_thumbnail (not applied yet)
+
+                // original image size
+                const int orig_w = original_size.width;
+                const int orig_h = original_size.height;
+
+                // 1) build candidate grids (cols, rows)
+                auto target_ratios = ds_build_target_ratios(min_num, max_num);
+
+                // 2) pick the grid that best matches the original aspect ratio
+                const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
+                auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
+                const int grid_cols = best.first;  // how many tiles horizontally
+                const int grid_rows = best.second; // how many tiles vertically
+
+                // 3) compute the target (forced) size; python did:
+                //    target_width  = image_size * cols
+                //    target_height = image_size * rows
+                const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };
+
+                // 4) prepare slice instructions, same style as the idefics3 branch
+                llava_uhd::slice_instructions instructions;
+                instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
+                instructions.refined_size  = refined_size;
+                instructions.grid_size     = clip_image_size{ grid_cols, grid_rows };
+
+                // in deepseek python they always produce *full* 640x640 blocks,
+                // so we can do a simple double loop over rows/cols:
+                for (int r = 0; r < grid_rows; ++r) {
+                    for (int c = 0; c < grid_cols; ++c) {
+                        const int x = c * image_size;
+                        const int y = r * image_size;
+
+                        instructions.slices.push_back(llava_uhd::slice_coordinates{
+                            /* x    */ x,
+                            /* y    */ y,
+                            /* size */ clip_image_size{ image_size, image_size }
+                        });
+                    }
+                }
+
+                // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
+                auto imgs = llava_uhd::slice_image(img, instructions);
+
+                // 6) cast & normalize like the idefics3 branch
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
+
+                // keep the grid info: the model may need to know how to reassemble / attend
+                res_imgs->grid_x = grid_cols;
+                res_imgs->grid_y = grid_rows;
+            }
+            break;
+
         default:
             LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type());
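To make the intended output of `build_global_local_features()` (see the clip.cpp changes above) easier to review, here is a small illustrative Python sketch. The tensor names are the ones used by the patch (`image_newline`, `view_seperator`); the 16x16 grid is just an example, not a value taken from the patch: each of the h rows of the global feature grid is followed by one image_newline embedding, and a single view_seperator embedding is appended at the end, giving h*(w+1)+1 output positions.

```python
# Illustrative sketch (not part of the patch) of the token layout produced by
# build_global_local_features(): one image_newline token after each row of the
# h x w feature grid, then a single view_seperator token terminating the view.
def n_output_tokens(h: int, w: int) -> int:
    return h * (w + 1) + 1

def layout(h: int, w: int) -> list[str]:
    tokens = []
    for row in range(h):
        tokens += [f"feat[{row},{col}]" for col in range(w)]
        tokens.append("<image_newline>")
    tokens.append("<view_seperator>")
    return tokens

# example: a 16x16 grid yields 16*17 + 1 = 273 positions
assert len(layout(16, 16)) == n_output_tokens(16, 16) == 273
```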