diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index eb43520f98..9a2ceed1dc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3693,6 +3693,13 @@ class Ernie4_5Model(TextModel): def set_vocab(self): self._set_vocab_sentencepiece() + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + def set_gguf_parameters(self): super().set_gguf_parameters() @@ -3702,6 +3709,10 @@ class Ernie4_5Model(TextModel): if (head_dim := self.hparams.get("head_dim")) is None: head_dim = self.hparams["hidden_size"] // num_heads + if "mlp_AR" in name or "vision_model" in name: + # skip vision model and projector tensors + return [] + if "ernie." in name: name = name.replace("ernie.", "model.") # split the qkv weights @@ -3811,6 +3822,49 @@ class Ernie4_5MoeModel(Ernie4_5Model): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("PaddleOCRVLForConditionalGeneration") +class PaddleOCRModel(Ernie4_5Model): + model_arch = gguf.MODEL_ARCH.PADDLEOCR + + +@ModelBase.register("PaddleOCRVisionModel") +class PaddleOCRVisionModel(MmprojModel): + # PaddleOCR-VL uses a modified version of Siglip + min_pixels: int = 0 + max_pixels: int = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.min_pixels = self.preprocessor_config["min_pixels"] + self.max_pixels = self.preprocessor_config["max_pixels"] + self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_vision is not None + hparams = self.hparams_vision + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR) + self.gguf_writer.add_vision_max_pixels(self.max_pixels) + self.gguf_writer.add_vision_min_pixels(self.min_pixels) + self.gguf_writer.add_vision_use_gelu(True) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + name = name.replace("visual.", "model.") + + if "vision_model" in name or "mlp_AR" in name: + if "packing_position_embedding" in name: + return [] # unused + elif "vision_model.head" in name: + # we don't yet support image embeddings for this model + return [] + else: + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + @ModelBase.register( "Qwen2VLModel", "Qwen2VLForConditionalGeneration", diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 31273b2b5a..26cc306184 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -284,6 +284,8 @@ class Keys: class ClipVision: PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" @@ -456,6 +458,7 @@ class MODEL_ARCH(IntEnum): RND1 = auto() PANGU_EMBED = auto() MISTRAL3 = auto() + PADDLEOCR = auto() MIMO2 = auto() LLAMA_EMBED = auto() MAINCODER = auto() @@ -877,6 +880,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", + MODEL_ARCH.PADDLEOCR: "paddleocr", MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.LLAMA_EMBED: "llama-embed", MODEL_ARCH.MAINCODER: "maincoder", @@ -3016,6 +3020,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PADDLEOCR: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.FALCON_H1: [ # Token embedding MODEL_TENSOR.TOKEN_EMBD, @@ -3610,6 +3628,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + PADDLEOCR = "paddleocr" LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7fbb78866b..39cb03191e 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1098,6 +1098,12 @@ class GGUFWriter: def add_vision_embedding_length(self, value: int) -> None: self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value) + def add_vision_max_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value) + + def add_vision_min_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value) + def add_vision_feed_forward_length(self, value: int) -> None: self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 84aa868809..bdd0af6800 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1256,6 +1256,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "mlp_AR.linear_{bid}", # PaddleOCR-VL "merger.mlp.{bid}", ), @@ -1492,6 +1493,7 @@ class TensorNameMap: "multi_modal_projector.pre_norm", "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm + "mlp_AR.pre_norm", # PaddleOCR-VL "merger.ln_q", ), @@ -1517,6 +1519,7 @@ class TensorNameMap: MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( "resampler.attn.out_proj", + "model.vision_model.head.attention.out_proj", ), MODEL_TENSOR.V_RESMPL_KV: ( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f337afd6b3..f68f014ee6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -61,6 +61,7 @@ add_library(llama models/dream.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp + models/paddleocr.cpp models/exaone.cpp models/exaone4.cpp models/exaone-moe.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a54bc1956a..b016788ee9 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -117,6 +117,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, @@ -710,6 +711,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_INTERNLM2: case LLM_ARCH_GRANITE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_PADDLEOCR: case LLM_ARCH_SMOLLM3: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: diff --git a/src/llama-arch.h b/src/llama-arch.h index 270d28b16a..658785929d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -121,6 +121,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 72490a89b5..59827bd49f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2175,7 +2175,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (arch == LLM_ARCH_ERNIE4_5_MOE) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -6276,6 +6280,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8021,6 +8026,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PADDLEOCR: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_HUNYUAN_MOE: { llm = std::make_unique(*this, params); @@ -8333,6 +8342,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_PADDLEOCR: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a23950d007..6063cc5290 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2418,6 +2418,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|calls|>" // solar-open || t.first == "" || t.first == "<|endoftext|>" + || t.first == "" // paddleocr || t.first == "<|eom_id|>" || t.first == "" || t.first == "_" diff --git a/src/models/models.h b/src/models/models.h index 3a44f7f140..eedbee2c46 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -158,6 +158,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_paddleocr : public llm_graph_context { + llm_build_paddleocr(const llama_model & model, const llm_graph_params & params); +}; + template struct llm_build_exaone4 : public llm_graph_context { llm_build_exaone4(const llama_model & model, const llm_graph_params & params); diff --git a/src/models/paddleocr.cpp b/src/models/paddleocr.cpp new file mode 100644 index 0000000000..1f6336eb97 --- /dev/null +++ b/src/models/paddleocr.cpp @@ -0,0 +1,119 @@ +#include "models.h" + +llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af32..a0777e1b47 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(mtmd models/llama4.cpp models/llava.cpp models/minicpmv.cpp + models/paddleocr.cpp models/pixtral.cpp models/qwen2vl.cpp models/qwen3vl.cpp diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index dd693623a2..5258ae841b 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -36,6 +36,8 @@ // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels" +#define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" @@ -227,6 +229,7 @@ enum projector_type { PROJECTOR_TYPE_MUSIC_FLAMINGO, PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, + PROJECTOR_TYPE_PADDLEOCR, PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, @@ -260,6 +263,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, + { PROJECTOR_TYPE_PADDLEOCR, "paddleocr"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9fa5afc390..01765b0b5a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -825,6 +825,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_COGVLM: { builder = std::make_unique(ctx, img); @@ -1220,6 +1224,14 @@ struct clip_model_loader { hparams.audio_window_len = 400; hparams.audio_hop_len = 160; } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + hparams.n_merge = 2; + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); + + hparams.set_warmup_n_tokens(28*28); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_LFM2A: { // audio preprocessing params @@ -1668,6 +1680,7 @@ struct clip_model_loader { model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_PADDLEOCR: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); @@ -2987,6 +3000,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_LIGHTONOCR: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); @@ -3143,6 +3157,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; default: @@ -3159,6 +3174,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; default: @@ -3254,6 +3270,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size; n_patches = x_patch * y_patch; } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + // dynamic size + int n_merge = ctx->model.hparams.n_merge; + int stride = n_merge * n_merge; + n_patches = CLIP_ALIGN(n_patches, stride) / stride; + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: { @@ -3501,6 +3524,29 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } + set_input_i32("positions", positions); + } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + const int merge_ratio = hparams.n_merge; + const int pw = image_size_width / patch_size; + const int ph = image_size_height / patch_size; + std::vector positions(n_pos * 4); + int ptr = 0; + for (int y = 0; y < ph; y += merge_ratio) { + for (int dy = 0; dy < 2; dy++) { + for (int x = 0; x < pw; x += merge_ratio) { + for (int dx = 0; dx < 2; dx++) { + positions[ ptr] = y + dy; + positions[ num_patches + ptr] = x + dx; + positions[2 * num_patches + ptr] = y + dy; + positions[3 * num_patches + ptr] = x + dx; + ptr++; + } + } + } + } + set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: @@ -3770,6 +3816,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_PADDLEOCR: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7b..61faf68826 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -52,6 +52,11 @@ struct clip_graph_kimivl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_paddleocr : clip_graph { + clip_graph_paddleocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_cogvlm : clip_graph { clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/paddleocr.cpp b/tools/mtmd/models/paddleocr.cpp new file mode 100644 index 0000000000..5d3a13fb57 --- /dev/null +++ b/tools/mtmd/models/paddleocr.cpp @@ -0,0 +1,52 @@ +#include "models.h" + +ggml_cgraph * clip_graph_paddleocr::build() { + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return ggml_rope_multi( + ctx0, cur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, + 32768, 10000, 1, 0, 1, 32, 1); + }; + + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + learned_pos_embd, + add_pos); + + cb(cur, "vit_out", -1); + + { + // mlp_AR paddleocr projector + float proj_norm_eps = 1e-5; + cur = build_norm(cur, + model.mm_input_norm_w, model.mm_input_norm_b, + NORM_TYPE_NORMAL, proj_norm_eps, -1); + + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + hparams.ffn_op, -1); + cb(cur, "mlp_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d037e834f3..4688586bd5 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -314,6 +314,10 @@ struct mtmd_context { img_beg = "<|begin_of_image|>"; img_end = "<|end_of_image|>"; + } else if (proj == PROJECTOR_TYPE_PADDLEOCR) { + // <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|> + img_beg = "<|IMAGE_START|>"; + img_end = "<|IMAGE_END|>"; } } @@ -877,6 +881,7 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) { case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_PADDLEOCR: return true; default: return false;