[update] add paddleocr vl text model instead of reusing ernie4.5

megemini 2026-01-12 21:50:49 +08:00
parent 64f0a46e1c
commit fbfa906910
13 changed files with 219 additions and 43 deletions


@ -601,11 +601,6 @@ common_chat_templates_ptr common_chat_templates_init(
"{%- if false %}");
}
// TODO @ngxson : hot fix for PaddleOCR
if (default_template_src.find("<|IMAGE_PLACEHOLDER|>") != std::string::npos) {
string_replace_all(default_template_src, "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>", "");
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;
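The deleted lines above are the temporary template hack this commit presumably makes unnecessary: with PaddleOCR promoted to its own architecture, the chat template no longer needs its image-marker triple stripped at init time. A minimal sketch of what the removed hot fix did (string_replace_all is llama.cpp's in-place replace-all helper; load_template is a hypothetical stand-in):

    // sketch: behavior of the removed hot fix
    std::string tmpl = load_template(); // hypothetical loader
    if (tmpl.find("<|IMAGE_PLACEHOLDER|>") != std::string::npos) {
        // erase the whole marker triple so text-only chat formatting still works
        string_replace_all(tmpl, "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>", "");
    }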


@ -3602,13 +3602,20 @@ class LLaDAModel(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM", "PaddleOCRVLForConditionalGeneration")
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
def set_vocab(self):
self._set_vocab_sentencepiece()
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
def set_gguf_parameters(self):
super().set_gguf_parameters()
@ -3739,6 +3746,12 @@ class Ernie4_5MoeModel(Ernie4_5Model):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("PaddleOCRVLForConditionalGeneration")
class PaddleOCRModel(Ernie4_5Model):
model_arch = gguf.MODEL_ARCH.PADDLEOCR
@ModelBase.register("PaddleOCRVisionModel")
class PaddleOCRVisionModel(MmprojModel):
# PaddleOCR-VL uses a modified version of Siglip
@ -3776,6 +3789,7 @@ class PaddleOCRVisionModel(MmprojModel):
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors
@ModelBase.register(
"Qwen2VLModel",
"Qwen2VLForConditionalGeneration",

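On the C++ side, the add_prefix_space flag that set_vocab above writes into the GGUF metadata is read back at load time. A minimal sketch with the gguf C API, assuming the conventional key name tokenizer.ggml.add_space_prefix (the key written by add_add_space_prefix):

    #include "gguf.h"
    // sketch: consumer side of gguf_writer.add_add_space_prefix (key name assumed)
    static bool read_add_space_prefix(const gguf_context * ctx, bool fallback) {
        const int64_t kid = gguf_find_key(ctx, "tokenizer.ggml.add_space_prefix");
        return kid < 0 ? fallback : gguf_get_val_bool(ctx, kid);
    }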

@ -449,6 +449,7 @@ class MODEL_ARCH(IntEnum):
RND1 = auto()
PANGU_EMBED = auto()
MISTRAL3 = auto()
PADDLEOCR = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -826,6 +827,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.PADDLEOCR: "paddleocr",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -2827,6 +2829,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PADDLEOCR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
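The new MODEL_TENSORS entry is the plain dense-attention set, so a converted file serializes to the standard GGUF tensor names. A quick sanity check as a sketch, using the gguf C API against a hypothetical converted file:

    #include "gguf.h"
    #include <cstdio>
    #include <initializer_list>
    // sketch: verify a few expected names in a converted model (file path hypothetical)
    int main() {
        gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ nullptr };
        gguf_context * ctx = gguf_init_from_file("PaddleOCR-VL.gguf", params);
        if (!ctx) return 1;
        for (const char * name : { "token_embd.weight", "output_norm.weight",
                                   "blk.0.attn_q.weight", "blk.0.ffn_gate.weight" }) {
            printf("%-22s %s\n", name, gguf_find_tensor(ctx, name) >= 0 ? "found" : "missing");
        }
        gguf_free(ctx);
    }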
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,


@ -60,6 +60,7 @@ add_library(llama
models/dream.cpp
models/ernie4-5-moe.cpp
models/ernie4-5.cpp
models/paddleocr.cpp
models/exaone.cpp
models/exaone4.cpp
models/falcon-h1.cpp


@ -114,6 +114,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -700,6 +701,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
case LLM_ARCH_INTERNLM2:
case LLM_ARCH_GRANITE:
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_PADDLEOCR:
case LLM_ARCH_SMOLLM3:
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:


@ -118,6 +118,7 @@ enum llm_arch {
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_PADDLEOCR,
LLM_ARCH_UNKNOWN,
};


@ -2056,7 +2056,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_PADDLEOCR:
{
// paddleocr needs mrope_section
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
if (arch == LLM_ARCH_ERNIE4_5_MOE) {
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
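hparams.rope_sections holds four entries, counted in rotary dim pairs (the convention ggml_rope_multi uses for M-RoPE); per the llama-arch key table the GGUF key is "%s.rope.dimension_sections". A self-contained sketch of how the sections partition rotary pairs across position channels; the section values below are illustrative assumptions, not taken from this commit:

    #include <cstdio>
    // sketch: which position channel rotates rotary pair p under m-rope
    static int mrope_channel(int p, const int sections[4]) {
        for (int c = 0, acc = 0; c < 4; ++c) {
            acc += sections[c];
            if (p < acc) return c; // 0: temporal, 1: height, 2: width, 3: spare
        }
        return 3;
    }
    int main() {
        const int sections[4] = {16, 24, 24, 0}; // assumed values for illustration
        printf("pair 0 -> %d, pair 20 -> %d, pair 50 -> %d\n",
               mrope_channel(0, sections), mrope_channel(20, sections), mrope_channel(50, sections));
        // prints: pair 0 -> 0, pair 20 -> 1, pair 50 -> 2
    }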
@ -5964,6 +5968,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_PADDLEOCR:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -7574,6 +7579,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_ernie4_5_moe>(*this, params);
} break;
case LLM_ARCH_PADDLEOCR:
{
llm = std::make_unique<llm_build_paddleocr>(*this, params);
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
@ -7866,6 +7875,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
case LLM_ARCH_PADDLEOCR:
return LLAMA_ROPE_TYPE_MROPE;
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_QWEN3VLMOE:
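Returning LLAMA_ROPE_TYPE_MROPE groups PADDLEOCR with QWEN2VL, so callers must supply four position channels per token rather than one. A hedged sketch of the buffer layout for a text-only batch, following the Qwen2-VL convention this arch is grouped with (the unused fourth channel stays at zero):

    #include <cstdint>
    #include <vector>
    // sketch: 4-channel positions for n_tokens of plain text under m-rope
    static std::vector<int32_t> text_mrope_pos(int32_t n_tokens) {
        std::vector<int32_t> pos(4 * n_tokens, 0);
        for (int ch = 0; ch < 3; ++ch)             // temporal / height / width
            for (int32_t i = 0; i < n_tokens; ++i)
                pos[ch * n_tokens + i] = i;        // all three track the token index
        return pos;                                // channel 3 stays 0 (section unused)
    }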


@ -2359,6 +2359,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|call|>" // o200k_harmony
|| t.first == "<end_of_turn>"
|| t.first == "<|endoftext|>"
|| t.first == "</s>" // paddleocr
|| t.first == "<|eom_id|>"
|| t.first == "<EOT>"
|| t.first == "_<EOT>"


@ -158,6 +158,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_paddleocr : public llm_graph_context {
llm_build_paddleocr(const llama_model & model, const llm_graph_params & params);
};
template <bool iswa>
struct llm_build_exaone4 : public llm_graph_context {
llm_build_exaone4(const llama_model & model, const llm_graph_params & params);

src/models/paddleocr.cpp (new file, 119 lines)

@ -0,0 +1,119 @@
#include "models.h"
llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
{
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
}
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_multi(
ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_multi(
ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
{
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
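Structurally this is the dense Ernie4.5 text graph with one substitution: the single-position NEOX-style rope is replaced by ggml_rope_multi, which takes the four-entry sections array so image tokens can rotate by separate row/column positions while text tokens share one index. An illustrative comparison (not a diff against real code; both calls take the same trailing YaRN parameters):

    // single-position rope, as a dense text-only model would build it:
    //   Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type,
    //                        n_ctx_orig, freq_base, freq_scale,
    //                        ext_factor, attn_factor, beta_fast, beta_slow);
    // multi-section rope used above, with inp_pos carrying 4 channels:
    //   Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, n_rot, sections,
    //                          rope_type, n_ctx_orig, freq_base, freq_scale,
    //                          ext_factor, attn_factor, beta_fast, beta_slow);

Everything else (RMS-norm, biased QKV, 1/sqrt(n_embd_head) attention scaling, SiLU-gated FFN, residual adds) is unchanged from the dense text stack.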


@ -39,7 +39,6 @@ struct clip_hparams {
int32_t image_min_pixels = -1;
int32_t image_max_pixels = -1;
int32_t n_merge = 0; // number of patch merges **per-side**
int32_t proj_scale_factor = 0;
float image_mean[3];
float image_std[3];


@ -1193,9 +1193,10 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
hparams.proj_scale_factor = 2;
hparams.set_limit_image_tokens(8, 1024);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
hparams.n_merge = 2;
// TODO(megemini): image token limits not specified by paddleocr vl?
hparams.set_limit_image_tokens(8, 4096);
hparams.set_warmup_n_tokens(28*28); // avoid OOM on warmup
} break;
default:
break;
@ -1460,7 +1461,7 @@ struct clip_model_loader {
model.mm_model_ln_kv_w = get_tensor(string_format(TN_RESAMPL_LN, "kv", "weight"));
model.mm_model_ln_kv_b = get_tensor(string_format(TN_RESAMPL_LN, "kv", "bias"));
model.mm_model_ln_post_w = get_tensor(string_format(TN_RESAMPL_LN, "post", "weight"));
model.mm_model_ln_post_b = get_tensor(string_format(TN_RESAMPL_LN, "post", "bias"));
} break;
case PROJECTOR_TYPE_GLM_EDGE:
{
@ -2869,6 +2870,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_PADDLEOCR:
return (img->nx / params.patch_size) / 2;
default:
break;
@ -2884,6 +2886,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_PADDLEOCR:
return (img->ny / params.patch_size) / 2;
default:
break;
@ -2971,8 +2974,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_PADDLEOCR:
{
// dynamic size
int scale_factor = ctx->model.hparams.proj_scale_factor;
int stride = scale_factor * scale_factor;
int n_merge = ctx->model.hparams.n_merge;
int stride = n_merge * n_merge;
n_patches = CLIP_ALIGN(n_patches, stride) / stride;
} break;
case PROJECTOR_TYPE_PIXTRAL:
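The PADDLEOCR case above pads the patch count up to a whole number of merge groups before dividing. A worked example, assuming CLIP_ALIGN is the usual round-up-to-multiple macro:

    #include <cstdio>
    #include <initializer_list>
    #define CLIP_ALIGN(x, n) (((x) + (n) - 1) / (n) * (n)) // assumed definition
    int main() {
        const int n_merge = 2, stride = n_merge * n_merge; // 2x2 patches -> 1 token
        for (int n_patches : {64, 9, 1024}) {
            printf("%4d patches -> %3d tokens\n",
                   n_patches, CLIP_ALIGN(n_patches, stride) / stride);
        }
        // 64 -> 16, 9 -> 3 (padded to 12 first), 1024 -> 256
    }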
@ -3219,6 +3222,29 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
const int merge_ratio = hparams.n_merge;
const int pw = image_size_width / patch_size;
const int ph = image_size_height / patch_size;
std::vector<int> positions(n_pos * 4);
int ptr = 0;
for (int dy = 0; dy < 2; dy++) {
for (int y = 0; y < ph; y += merge_ratio) {
for (int x = 0; x < pw; x += merge_ratio) {
for (int dx = 0; dx < 2; dx++) {
positions[ ptr] = y + dy;
positions[ num_patches + ptr] = x + dx;
positions[2 * num_patches + ptr] = y + dy;
positions[3 * num_patches + ptr] = x + dx;
ptr++;
}
}
}
}
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_QWEN25VL:
@ -3304,7 +3330,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_LIGHTONOCR:
{
// set the 2D positions
@ -3499,6 +3524,7 @@ bool clip_is_mrope(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL
|| ctx->proj_type() == PROJECTOR_TYPE_PADDLEOCR
|| ctx->proj_type() == PROJECTOR_TYPE_GLM4V;
}
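The PADDLEOCR positions case above emits four channels per patch in (row, col, row, col) order, walking the grid merge-group by merge-group, which lines up with the four d_head/4 vision-rope sections in the graph below. A self-contained replica of that loop for a toy 4x4 patch grid with n_merge = 2:

    #include <cstdio>
    #include <vector>
    // sketch: reproduce the 4-channel position ids for 16 patches (n_pos = 16)
    int main() {
        const int pw = 4, ph = 4, merge = 2, n_pos = pw * ph;
        std::vector<int> positions(n_pos * 4);
        int ptr = 0;
        for (int dy = 0; dy < 2; dy++)
            for (int y = 0; y < ph; y += merge)
                for (int x = 0; x < pw; x += merge)
                    for (int dx = 0; dx < 2; dx++) {
                        positions[ptr]             = y + dy; // channel 0: row
                        positions[n_pos + ptr]     = x + dx; // channel 1: col
                        positions[2 * n_pos + ptr] = y + dy; // channel 2: row
                        positions[3 * n_pos + ptr] = x + dx; // channel 3: col
                        ptr++;
                    }
        for (int c = 0; c < 4; c++, printf("\n"))
            for (int i = 0; i < n_pos; i++)
                printf("%d ", positions[c * n_pos + i]);
    }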


@ -1,23 +1,23 @@
#include "models.h"
ggml_cgraph * clip_graph_paddleocr::build() {
// 2D input positions
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
const int n_pos = n_patches;
const int num_position_ids = n_pos * 4; // m-rope requires 4 dims per position
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
ggml_tensor * learned_pos_embd = resize_position_embeddings();
ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
// build ViT with 2D position embeddings
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
// first half is X axis and second half is Y axis
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
return ggml_rope_multi(
ctx0, cur, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
32768, 10000, 1, 0, 1, 32, 1);
};
ggml_tensor * learned_pos_embd = resize_position_embeddings();
ggml_tensor * inp = build_inp();
ggml_tensor * cur = build_vit(
inp, n_patches,
@ -29,28 +29,16 @@ ggml_cgraph * clip_graph_paddleocr::build() {
cb(cur, "vit_out", -1);
{
// mlp_AR
float proj_norm_eps = 1e-5; // PaddleOCR uses hard-coded value eps=1e-5 for Projector
// mlp_AR paddleocr projector
float proj_norm_eps = 1e-5;
cur = build_norm(cur,
model.mm_input_norm_w, model.mm_input_norm_b,
NORM_TYPE_NORMAL, proj_norm_eps, -1);
//cur = build_patch_merge_permute(cur, hparams.proj_scale_factor);
// stack and padding
int64_t stride = hparams.proj_scale_factor * hparams.proj_scale_factor;
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
int64_t n_tokens_padded = CLIP_ALIGN(n_tokens, stride);
int64_t n_pad = n_tokens_padded - n_tokens;
if (n_pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, n_pad * n_embd, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur,
n_embd * stride,
n_tokens_padded / stride,
ggml_row_size(cur->type, n_embd * stride), 0);
cb(cur, "after_stacked", -1);
const int scale_factor = model.hparams.n_merge;
int width = img.nx / patch_size;
int height = img.ny / patch_size;
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, width / scale_factor * height / scale_factor, 1);
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
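The ggml_reshape_3d above performs the 2x2 patch merge the projector expects: four neighbouring patch embeddings are stacked into one token of width n_embd * scale_factor^2 before the projector FFN (mm_1_w, ...). A shape walk-through under assumed sizes (n_embd and patch_size are illustrative; this excerpt does not pin down PaddleOCR-VL's encoder dimensions):

    #include <cstdio>
    int main() {
        const int n_embd = 1152, patch_size = 14, scale = 2; // assumed sizes
        const int nx = 448, ny = 448;                        // example image
        const int w = nx / patch_size, h = ny / patch_size;  // 32 x 32 patches
        printf("vit_out: [%d, %d]\n", n_embd, w * h);        // [1152, 1024]
        printf("merged : [%d, %d]\n", n_embd * scale * scale,
               (w / scale) * (h / scale));                   // [4608, 256]
    }

Note the merged token count, (w/scale) * (h/scale) = 256, agrees with the clip_n_output_tokens math for 1024 patches shown earlier.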