Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr

This commit is contained in:
bluebread 2025-11-20 13:36:07 +00:00
commit 1268dc3fd1
4 changed files with 86 additions and 39 deletions

View File

@ -91,7 +91,7 @@
#define TN_MM_INP_NORM_B "mm.input_norm.bias"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_MM_PROJECTOR "mm.model.fc.%s" // idefics3, deepseekocr
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)

View File

@ -316,7 +316,8 @@ struct clip_model {
ggml_tensor * post_ln_w;
ggml_tensor * post_ln_b;
ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
ggml_tensor * fc_w;
ggml_tensor * fc_b;
ggml_tensor * mm_fc_w;
ggml_tensor * mm_fc_b;
@ -623,7 +624,7 @@ struct clip_graph {
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur);
cur = ggml_mul_mat(ctx0, model.fc_w, cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block
@ -689,7 +690,8 @@ struct clip_graph {
if (hparams.is_global_attn(il) == false) {
// local attention layer - apply window partition
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
cur = ggml_win_part(ctx0, cur, 14);
//cur = ggml_win_part(ctx0, cur, 14);
cur = window_partition(ctx0, cur, 14);
}
const int64_t W = cur->ne[1];
@ -761,7 +763,7 @@ struct clip_graph {
if (hparams.is_global_attn(il) == false) {
// local attention layer - reverse window partition
cur = ggml_win_unpart(ctx0, cur, w0, h0, 14);
cur = window_unpartition(ctx0, cur, w0, h0, 14);
}
// re-add the layer input, e.g., residual
@ -844,15 +846,12 @@ struct clip_graph {
ggml_row_size(global_features_2->type, n_embd), 0);
ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1);
global_features = build_global_local_features(
ctx0,
global_features,
n_patches_y,
n_patches_x,
n_embd
);
global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd, n_patches);
global_features = ggml_cont(ctx0, global_features);
global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
global_features = ggml_add(ctx0, global_features, model.fc_b);
global_features = build_global_local_features(ctx0,global_features);
ggml_build_forward_expand(gf, global_features);
return gf;
}
@ -861,41 +860,32 @@ struct clip_graph {
// view_separator: [n_dim]
ggml_tensor * build_global_local_features(ggml_context * ctx0,
ggml_tensor * global_features,
int h,
int w,
int n_dim) {
ggml_tensor * global_features) {
GGML_ASSERT(model.image_newline != nullptr);
GGML_ASSERT(model.view_seperator != nullptr);
GGML_ASSERT(global_features->ne[0] == static_cast<int64_t>(n_dim));
GGML_ASSERT(global_features->ne[1] == static_cast<int64_t>(2 * (h * w)));
// 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h)
t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h)
t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
nl = ggml_cont(ctx0, nl);
// 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim)
ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim)
nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim)
nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim)
// 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim)
t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
// 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1))
t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h)
t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1))
t = ggml_reshape_2d(ctx0, t, 1280, 64 * (64 + 1)); // (n_dim, h*(w+1))
// 5) append view_separator as an extra "token":
// view_separator: [n_dim] -> [n_dim, 1]
ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, 1280, 1); // (n_dim, 1)
// concat along token dimension (dim=1):
ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
return global_local_features;
return t;
}
@ -2476,6 +2466,46 @@ private:
return inpL;
}
static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {
auto [c, w, h, b] = x->ne;
// same as
// x = ggml_win_part(m, x, window);
// x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]);
int64_t px = (window - w % window) % window;
int64_t py = (window - h % window) % window;
int64_t npw = (w + px) / window;
int64_t nph = (h + py) / window;
if (px > 0 || py > 0) {
x = ggml_pad(ctx, x, 0, int(px), int(py), 0);
}
x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b);
return x;
}
static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) {
int64_t c = x->ne[0];
// same as
// x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]);
// x = ggml_win_unpart(m, x, w, h, window);
int64_t px = (window - w % window) % window;
int64_t py = (window - h % window) % window;
int64_t npw = (w + px) / window;
int64_t nph = (h + py) / window;
int64_t b = x->ne[3] / (npw * nph);
x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b);
x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
x = ggml_reshape_4d(m, x, c, w + px, h + py, b);
x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
x = ggml_cont(m, x);
return x;
}
// build the input after conv2d (inp_raw --> patches)
// returns tensor with shape [n_embd, n_patches]
ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,
@ -3488,7 +3518,7 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_IDEFICS3:
{
model.projection = get_tensor(TN_MM_PROJECTOR);
model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
@ -3561,13 +3591,13 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_LLAMA4:
{
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
} break;
case PROJECTOR_TYPE_COGVLM:
{
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight"));
@ -3617,6 +3647,9 @@ struct clip_model_loader {
}
model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
model.view_seperator = get_tensor(TN_IMAGE_SEPERATOR, false);
model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
break;
default:
@ -5086,6 +5119,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
n_patches += 2; // for BOI and EOI token embeddings
} break;
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
n_patches += 2;
} break;
default:
GGML_ABORT("unsupported projector type");
}
@ -5417,6 +5454,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_JANUS_PRO:
case PROJECTOR_TYPE_COGVLM:
case PROJECTOR_TYPE_DEEPSEEKOCR:
{
// do nothing
} break;
@ -5512,7 +5550,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_GEMMA3:
return ctx->model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->model.projection->ne[1];
return ctx->model.fc_w->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
return ctx->model.mm_2_w->ne[1];
@ -5527,6 +5565,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];
case PROJECTOR_TYPE_DEEPSEEKOCR:
return ctx->model.fc_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
}
@ -5557,6 +5597,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
}
bool clip_is_deepseekocr(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_DEEPSEEKOCR;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
return ctx->model.modality == CLIP_MODALITY_VISION;
}

View File

@ -105,6 +105,8 @@ bool clip_is_glm(const struct clip_ctx * ctx);
bool clip_is_qwen2vl(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);
bool clip_is_deepseekocr(const struct clip_ctx * ctx);
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

View File

@ -810,7 +810,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
if (clip_is_llava(ctx_clip)
|| clip_is_minicpmv(ctx_clip)
|| clip_is_glm(ctx_clip)) {
|| clip_is_glm(ctx_clip)
|| clip_is_deepseekocr(ctx_clip)) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) {