support PaddleOCR-VL

This commit is contained in:
megemini 2025-12-19 12:24:55 +08:00
parent 6853bee680
commit b4cde7c7d9
12 changed files with 184 additions and 26 deletions

View File

@ -601,6 +601,11 @@ common_chat_templates_ptr common_chat_templates_init(
"{%- if false %}");
}
// TODO @ngxson : hot fix for PaddleOCR
if (default_template_src.find("<|IMAGE_PLACEHOLDER|>") != std::string::npos) {
string_replace_all(default_template_src, "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>", "");
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;

View File

@ -3602,7 +3602,7 @@ class LLaDAModel(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM", "PaddleOCRVLForConditionalGeneration")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
@ -3618,6 +3618,10 @@ class Ernie4_5Model(TextModel):
if (head_dim := self.hparams.get("head_dim")) is None:
head_dim = self.hparams["hidden_size"] // num_heads
if "mlp_AR" in name or "vision_model" in name:
# skip vision model and projector tensors
return []
if "ernie." in name:
name = name.replace("ernie.", "model.")
# split the qkv weights
@ -3735,6 +3739,42 @@ class Ernie4_5MoeModel(Ernie4_5Model):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("PaddleOCRVisionModel")
class PaddleOCRVisionModel(MmprojModel):
# PaddleOCR-VL uses a modified version of Siglip
min_pixels: int = 0
max_pixels: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_pixels = self.preprocessor_config["min_pixels"]
self.max_pixels = self.preprocessor_config["max_pixels"]
self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels))
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR)
self.gguf_writer.add_vision_max_pixels(self.max_pixels)
self.gguf_writer.add_vision_min_pixels(self.min_pixels)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
name = name.replace("visual.", "model.")
if "vision_model" in name or "mlp_AR" in name:
if "packing_position_embedding" in name:
return [] # unused
elif "vision_model.head" in name:
# we don't yet support image embeddings for this model
return []
else:
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors
@ModelBase.register(
"Qwen2VLModel",

View File

@ -281,6 +281,8 @@ class Keys:
class ClipVision:
IMAGE_SIZE = "clip.vision.image_size"
MAX_PIXELS = "clip.vision.max_pixels"
MIN_PIXELS = "clip.vision.min_pixels"
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@ -3360,6 +3362,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
PADDLEOCR = "paddleocr"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"

View File

@ -1085,6 +1085,12 @@ class GGUFWriter:
def add_vision_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
def add_vision_max_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.MAX_PIXELS, value)
def add_vision_min_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.MIN_PIXELS, value)
def add_vision_feed_forward_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)

View File

@ -1207,6 +1207,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"visual.merger.mlp.{bid}", # qwen2vl
"mlp_AR.linear_{bid}", # PaddleOCR-VL
),
MODEL_TENSOR.V_MMPROJ_FC: (
@ -1432,6 +1433,7 @@ class TensorNameMap:
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
"mlp_AR.pre_norm", # PaddleOCR-VL
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
@ -1456,6 +1458,7 @@ class TensorNameMap:
MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
"resampler.attn.out_proj",
"model.vision_model.head.attention.out_proj",
),
MODEL_TENSOR.V_RESMPL_KV: (

View File

@ -21,6 +21,7 @@ add_library(mtmd
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp
models/paddleocr.cpp
models/pixtral.cpp
models/qwen2vl.cpp
models/qwen3vl.cpp

View File

@ -109,12 +109,14 @@
#define TN_DEEPSTACK_FC2 "v.deepstack.%d.fc2.%s" // qwen3vl deepstack
// mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
#define TN_MINICPMV_PROJ "resampler.proj.weight"
#define TN_MINICPMV_KV_PROJ "resampler.kv.weight"
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
#define TN_MINICPMV_LN "resampler.ln_%s.%s"
#define TN_RESAMPL_POS_EMBD_K "resampler.pos_embed_k"
#define TN_RESAMPL_QUERY "resampler.query"
#define TN_RESAMPL_PROJ "resampler.proj.weight"
#define TN_RESAMPL_KV_PROJ "resampler.kv.weight"
#define TN_RESAMPL_ATTN "resampler.attn.%s.%s"
#define TN_RESAMPL_LN "resampler.ln_%s.%s"
#define TN_RESAMPL_FFN_UP "resampler.ffn_up.%s"
#define TN_RESAMPL_FFN_DOWN "resampler.ffn_down.%s"
#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
@ -167,6 +169,7 @@ enum projector_type {
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_PADDLEOCR,
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
@ -195,6 +198,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},

View File

@ -39,6 +39,7 @@ struct clip_hparams {
int32_t image_min_pixels = -1;
int32_t image_max_pixels = -1;
int32_t n_merge = 0; // number of patch merges **per-side**
int32_t proj_scale_factor = 0;
float image_mean[3];
float image_std[3];

View File

@ -825,6 +825,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_kimivl>(ctx, img);
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
builder = std::make_unique<clip_graph_paddleocr>(ctx, img);
} break;
case PROJECTOR_TYPE_COGVLM:
{
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
@ -1187,6 +1191,10 @@ struct clip_model_loader {
hparams.audio_window_len = 400;
hparams.audio_hop_len = 160;
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
hparams.proj_scale_factor = 2;
} break;
default:
break;
}
@ -1432,25 +1440,25 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_MINICPMV:
{
// model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
// model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_RESAMPL_POS_EMBD);
model.mm_model_pos_embed_k = get_tensor(TN_RESAMPL_POS_EMBD_K);
model.mm_model_query = get_tensor(TN_RESAMPL_QUERY);
model.mm_model_proj = get_tensor(TN_RESAMPL_PROJ);
model.mm_model_kv_proj = get_tensor(TN_RESAMPL_KV_PROJ);
model.mm_model_attn_q_w = get_tensor(string_format(TN_RESAMPL_ATTN, "q", "weight"));
model.mm_model_attn_k_w = get_tensor(string_format(TN_RESAMPL_ATTN, "k", "weight"));
model.mm_model_attn_v_w = get_tensor(string_format(TN_RESAMPL_ATTN, "v", "weight"));
model.mm_model_attn_q_b = get_tensor(string_format(TN_RESAMPL_ATTN, "q", "bias"));
model.mm_model_attn_k_b = get_tensor(string_format(TN_RESAMPL_ATTN, "k", "bias"));
model.mm_model_attn_v_b = get_tensor(string_format(TN_RESAMPL_ATTN, "v", "bias"));
model.mm_model_attn_o_w = get_tensor(string_format(TN_RESAMPL_ATTN, "out", "weight"));
model.mm_model_attn_o_b = get_tensor(string_format(TN_RESAMPL_ATTN, "out", "bias"));
model.mm_model_ln_q_w = get_tensor(string_format(TN_RESAMPL_LN, "q", "weight"));
model.mm_model_ln_q_b = get_tensor(string_format(TN_RESAMPL_LN, "q", "bias"));
model.mm_model_ln_kv_w = get_tensor(string_format(TN_RESAMPL_LN, "kv", "weight"));
model.mm_model_ln_kv_b = get_tensor(string_format(TN_RESAMPL_LN, "kv", "bias"));
model.mm_model_ln_post_w = get_tensor(string_format(TN_RESAMPL_LN, "post", "weight"));
model.mm_model_ln_post_b = get_tensor(string_format(TN_RESAMPL_LN, "post", "bias"));
} break;
case PROJECTOR_TYPE_GLM_EDGE:
{
@ -1505,6 +1513,7 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@ -2701,6 +2710,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_LIGHTONOCR:
{
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
@ -2956,6 +2966,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
// dynamic size
int scale_factor = ctx->model.hparams.proj_scale_factor;
int stride = scale_factor * scale_factor;
n_patches = CLIP_ALIGN(n_patches, stride) / stride;
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
{
@ -3285,6 +3302,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_LIGHTONOCR:
{
// set the 2D positions
@ -3453,6 +3471,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];

View File

@ -42,6 +42,11 @@ struct clip_graph_kimivl : clip_graph {
ggml_cgraph * build() override;
};
struct clip_graph_paddleocr : clip_graph {
clip_graph_paddleocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_cogvlm : clip_graph {
clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

View File

@ -0,0 +1,67 @@
#include "models.h"
ggml_cgraph * clip_graph_paddleocr::build() {
// 2D input positions
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * learned_pos_embd = resize_position_embeddings();
// build ViT with 2D position embeddings
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
// first half is X axis and second half is Y axis
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
};
ggml_tensor * inp = build_inp();
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
learned_pos_embd,
add_pos);
cb(cur, "vit_out", -1);
{
// mlp_AR
float proj_norm_eps = 1e-5; // PaddleOCR uses hard-coded value eps=1e-5 for Projector
cur = build_norm(cur,
model.mm_input_norm_w, model.mm_input_norm_b,
NORM_TYPE_NORMAL, proj_norm_eps, -1);
//cur = build_patch_merge_permute(cur, hparams.proj_scale_factor);
// stack and padding
int64_t stride = hparams.proj_scale_factor * hparams.proj_scale_factor;
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
int64_t n_tokens_padded = CLIP_ALIGN(n_tokens, stride);
int64_t n_pad = n_tokens_padded - n_tokens;
if (n_pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, n_pad * n_embd, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur,
n_embd * stride,
n_tokens_padded / stride,
ggml_row_size(cur->type, n_embd * stride), 0);
cb(cur, "after_stacked", -1);
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
hparams.ffn_op, -1);
cb(cur, "mlp_out", -1);
}
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}

View File

@ -313,6 +313,10 @@ struct mtmd_context {
img_beg = "<|begin_of_image|>";
img_end = "<|end_of_image|>";
} else if (proj == PROJECTOR_TYPE_PADDLEOCR) {
// <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|>
img_beg = "<|IMAGE_START|>";
img_end = "<|IMAGE_END|>";
}
}