From 63a042f21e19c90d51645741fbd18f2d14c9864f Mon Sep 17 00:00:00 2001
From: Saba Fallah <10401143+sfallah@users.noreply.github.com>
Date: Tue, 18 Nov 2025 09:43:11 +0100
Subject: [PATCH 1/3] concat image_newline and image_seperator tokens

---
 tools/mtmd/clip-impl.h |  2 +-
 tools/mtmd/clip.cpp    | 67 ++++++++++++++++++++----------------------
 2 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index ba094cc25b..63d5905566 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -91,7 +91,7 @@
 #define TN_MM_INP_NORM_B   "mm.input_norm.bias"
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
-#define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
+#define TN_MM_PROJECTOR    "mm.model.fc.%s"             // idefics3, deepseekocr
 #define TN_MM_PATCH_MERGER "mm.patch_merger.weight"     // mistral small 3.1
 #define TN_TOK_IMG_BREAK   "v.token_embd.img_break"     // pixtral
 #define TN_TOK_GLM_BOI     "adapter.boi"                // glm-edge (these embeddings are not in text model)
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 8bd1eef4bf..99b5ab45d9 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -316,7 +316,8 @@ struct clip_model {
     ggml_tensor * post_ln_w;
     ggml_tensor * post_ln_b;
 
-    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
+    ggml_tensor * fc_w;
+    ggml_tensor * fc_b;
 
     ggml_tensor * mm_fc_w;
     ggml_tensor * mm_fc_b;
@@ -623,7 +624,7 @@
             // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
             const int scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);
-            cur = ggml_mul_mat(ctx0, model.projection, cur);
+            cur = ggml_mul_mat(ctx0, model.fc_w, cur);
 
         } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
             // pixel unshuffle block
@@ -844,15 +845,12 @@
                 ggml_row_size(global_features_2->type, n_embd), 0);
         ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1);
 
-        global_features = build_global_local_features(
-            ctx0,
-            global_features,
-            n_patches_y,
-            n_patches_x,
-            n_embd
-        );
+        global_features = ggml_reshape_2d(ctx0, global_features, 2 * n_embd, n_patches);
+        global_features = ggml_cont(ctx0, global_features);
+        global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
+        global_features = ggml_add(ctx0, global_features, model.fc_b);
+        global_features = build_global_local_features(ctx0, global_features);
 
         ggml_build_forward_expand(gf, global_features);
-
         return gf;
     }
@@ -861,41 +859,31 @@
     // view_separator: [n_dim]
     ggml_tensor * build_global_local_features(ggml_context * ctx0,
-                                              ggml_tensor * global_features,
-                                              int h,
-                                              int w,
-                                              int n_dim) {
+                                              ggml_tensor * global_features) {
         GGML_ASSERT(model.image_newline != nullptr);
         GGML_ASSERT(model.view_seperator != nullptr);
-        GGML_ASSERT(global_features->ne[0] == static_cast<int64_t>(n_dim));
-        GGML_ASSERT(global_features->ne[1] == static_cast<int64_t>(2 * (h * w)));
 
         // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
-        ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h)
-        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
+        ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h)
+        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
+        ggml_tensor * nl = ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3);
+        nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
 
         // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
-        ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim)
-
-        ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim)
-        nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim)
-        nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim)
-
-        // 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim)
         t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
 
-        // 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1))
-        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h)
-        t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1))
+        t = ggml_reshape_2d(ctx0, t, 1280, 64 * (64 + 1)); // (n_dim, h*(w+1))
+
         // 5) append view_separator as an extra "token":
         // view_separator: [n_dim] -> [n_dim, 1]
-        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
+        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, 1280, 1); // (n_dim, 1)
 
         // concat along token dimension (dim=1):
-        ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
+        t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
 
-        return global_local_features;
+        return t;
     }
@@ -3488,7 +3476,7 @@
                 } break;
             case PROJECTOR_TYPE_IDEFICS3:
                 {
-                    model.projection = get_tensor(TN_MM_PROJECTOR);
+                    model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                 } break;
             case PROJECTOR_TYPE_LFM2:
             case PROJECTOR_TYPE_KIMIVL:
@@ -3561,13 +3549,13 @@
                 } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
-                    model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                     model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
                     model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
                 } break;
             case PROJECTOR_TYPE_COGVLM:
                 {
-                    model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                     model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
                     model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
                     model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight"));
@@ -3617,6 +3605,9 @@
                     }
                     model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
                     model.view_seperator = get_tensor(TN_IMAGE_SEPERATOR, false);
+                    model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
+                    model.fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
+
                     break;
 
             default:
@@ -5086,6 +5077,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             {
                 n_patches += 2; // for BOI and EOI token embeddings
             } break;
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
+            {
+                n_patches += 2;
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
@@ -5512,7 +5507,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GEMMA3:
            return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
-            return ctx->model.projection->ne[1];
+            return ctx->model.fc_w->ne[1];
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
             return ctx->model.mm_2_w->ne[1];
@@ -5527,6 +5522,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
            return ctx->model.mm_4h_to_h_w->ne[1];
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
+            return ctx->model.fc_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
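Note on the layout built by build_global_local_features above: for the hardcoded deepseek-ocr global view (a 64x64 grid of 1280-dim tokens), one image_newline token is appended per row and a single view_seperator token terminates the sequence, so the projector emits 64*(64+1) + 1 = 4161 embeddings. The stand-alone program below only replays that arithmetic as a sanity check; it is an illustration, not part of the patch.

    #include <cstdio>

    int main() {
        const int w = 64, h = 64;                     // grid hardcoded in the patch
        const int tokens_per_row = w + 1;             // w patches + 1 newline token
        const int n_tokens = h * tokens_per_row + 1;  // +1 trailing view separator
        printf("total tokens: %d\n", n_tokens);       // 4161
        for (int r = 0; r < 2; r++) {
            // the newline terminating row r (0-based) sits at this token index
            printf("row %d newline at token %d\n", r, r * tokens_per_row + w);
        }
        printf("separator at token %d\n", n_tokens - 1); // 4160
        return 0;
    }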
From 89afda8da90024aaf908448a2bb8dafee739934c Mon Sep 17 00:00:00 2001
From: Saba Fallah <10401143+sfallah@users.noreply.github.com>
Date: Tue, 18 Nov 2025 10:26:32 +0100
Subject: [PATCH 2/3] visual_model warmup (technically) works

---
 tools/mtmd/clip.cpp | 5 +++++
 tools/mtmd/clip.h   | 2 ++
 tools/mtmd/mtmd.cpp | 3 ++-
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 99b5ab45d9..797f921f50 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -5412,6 +5412,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_JANUS_PRO:
         case PROJECTOR_TYPE_COGVLM:
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
             {
                 // do nothing
             } break;
@@ -5554,6 +5555,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
 }
 
+bool clip_is_deepseekocr(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_DEEPSEEKOCR;
+}
+
 bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
     return ctx->model.modality == CLIP_MODALITY_VISION;
 }
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index 3e4c985f11..458ee98fc7 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -105,6 +105,8 @@ bool clip_is_glm(const struct clip_ctx * ctx);
 bool clip_is_qwen2vl(const struct clip_ctx * ctx);
 bool clip_is_llava(const struct clip_ctx * ctx);
 bool clip_is_gemma3(const struct clip_ctx * ctx);
+bool clip_is_deepseekocr(const struct clip_ctx * ctx);
+
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
 
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index e599137769..16349e8f40 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -810,7 +810,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
 
     if (clip_is_llava(ctx_clip)
             || clip_is_minicpmv(ctx_clip)
-            || clip_is_glm(ctx_clip)) {
+            || clip_is_glm(ctx_clip)
+            || clip_is_deepseekocr(ctx_clip)) {
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
         for (size_t i = 0; i < entries.size(); i++) {
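Note: patch 2 routes deepseekocr through the same non-batched path as llava, minicpmv and glm, so mtmd_encode encodes each f32 entry separately. The stand-alone sketch below mirrors that predicate-gated fallback in miniature; Image, supports_batching and encode_one are hypothetical stand-ins for the mtmd types and calls, not real API.

    #include <cstdio>
    #include <vector>

    struct Image { int id; };

    // stand-in for clip_is_deepseekocr() and friends: encoders that cannot
    // batch are sent through the per-image loop
    static bool supports_batching(bool is_deepseekocr) { return !is_deepseekocr; }

    static void encode_one(const Image & img) { std::printf("encoded image %d\n", img.id); }

    static void encode_all(const std::vector<Image> & batch, bool is_deepseekocr) {
        if (!supports_batching(is_deepseekocr)) {
            for (const Image & img : batch) {
                encode_one(img); // one encode call per entry, as in mtmd_encode
            }
            return;
        }
        // batched path (clip_image_batch_encode) would go here
    }

    int main() {
        encode_all({{0}, {1}, {2}}, /*is_deepseekocr=*/true);
        return 0;
    }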
From 88032f46b1cf496670fb029dfcfa071ea2e31e02 Mon Sep 17 00:00:00 2001
From: Saba Fallah <10401143+sfallah@users.noreply.github.com>
Date: Thu, 20 Nov 2025 10:07:54 +0100
Subject: [PATCH 3/3] window partitioning using standard ggml ops

---
 tools/mtmd/clip.cpp | 50 +++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 46 insertions(+), 4 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 797f921f50..40b60cbfd5 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -690,7 +690,8 @@
             if (hparams.is_global_attn(il) == false) {
                 // local attention layer - apply window partition
                 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
-                cur = ggml_win_part(ctx0, cur, 14);
+                //cur = ggml_win_part(ctx0, cur, 14);
+                cur = window_partition(ctx0, cur, 14);
             }
 
             const int64_t W = cur->ne[1];
@@ -762,7 +763,7 @@
 
             if (hparams.is_global_attn(il) == false) {
                 // local attention layer - reverse window partition
-                cur = ggml_win_unpart(ctx0, cur, w0, h0, 14);
+                cur = window_unpartition(ctx0, cur, w0, h0, 14);
             }
 
             // re-add the layer input, e.g., residual
@@ -865,9 +866,10 @@
 
         // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
         ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h)
-        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
-        ggml_tensor * nl = ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3);
+        t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
+        ggml_tensor * nl = ggml_cont(ctx0, ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
         nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
+        nl = ggml_cont(ctx0, nl);
 
         // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
@@ -2464,6 +2466,46 @@ private:
         return inpL;
     }
 
+    static ggml_tensor * window_partition(ggml_context * ctx, ggml_tensor * x, int window) {
+        auto [c, w, h, b] = x->ne;
+        // same as:
+        //   x = ggml_win_part(ctx, x, window);
+        //   x = ggml_reshape_3d(ctx, x, c, window * window, x->ne[3]);
+
+        int64_t px  = (window - w % window) % window; // pad width up to a multiple of window
+        int64_t py  = (window - h % window) % window; // pad height up to a multiple of window
+        int64_t npw = (w + px) / window;              // number of windows along width
+        int64_t nph = (h + py) / window;              // number of windows along height
+
+        if (px > 0 || py > 0) {
+            x = ggml_pad(ctx, x, 0, int(px), int(py), 0);
+        }
+        x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
+        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
+        x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b);
+        return x;
+    }
+
+    static ggml_tensor * window_unpartition(ggml_context * ctx, ggml_tensor * x, int w, int h, int window) {
+        int64_t c = x->ne[0];
+        // same as:
+        //   x = ggml_reshape_4d(ctx, x, c, window, window, x->ne[2]);
+        //   x = ggml_win_unpart(ctx, x, w, h, window);
+
+        int64_t px  = (window - w % window) % window;
+        int64_t py  = (window - h % window) % window;
+        int64_t npw = (w + px) / window;
+        int64_t nph = (h + py) / window;
+
+        int64_t b = x->ne[3] / (npw * nph);
+        x = ggml_reshape_4d(ctx, x, c * window, window, npw, nph * b);
+        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
+        x = ggml_reshape_4d(ctx, x, c, w + px, h + py, b);
+        x = ggml_view_4d(ctx, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); // crop the padding
+        x = ggml_cont(ctx, x);
+        return x;
+    }
+
     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]
     ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,
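Note: the two helpers above reproduce ggml_win_part/ggml_win_unpart with stock ops: pad w and h up to multiples of the window size, regroup into windows via reshape and permute, and crop the padding off again on the way back. The stand-alone sketch below replays the same index arithmetic on a plain array and asserts that unpartition(partition(x)) == x; win_part_cpu and win_unpart_cpu are hypothetical names, and only the padding formulas ((window - w % window) % window, etc.) are taken from the patch.

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // ggml-order layout: ne = [c, w, h] with c fastest; windows come out as
    // ne = [c, win, win, n_windows], matching window_partition above
    static std::vector<float> win_part_cpu(const std::vector<float> & x,
                                           int c, int w, int h, int win) {
        const int px  = (win - w % win) % win; // pad width to a multiple of win
        const int py  = (win - h % win) % win; // pad height to a multiple of win
        const int npw = (w + px) / win;        // windows along width
        const int nph = (h + py) / win;        // windows along height
        std::vector<float> out((size_t) c * win * win * npw * nph, 0.0f); // padding stays 0
        for (int y = 0; y < h; y++) {
            for (int i = 0; i < w; i++) {
                const int wy = y / win, iy = y % win; // window row / offset inside window
                const int wx = i / win, ix = i % win; // window col / offset inside window
                const size_t dst = ((((size_t) wy * npw + wx) * win + iy) * win + ix) * c;
                const size_t src = ((size_t) y * w + i) * c;
                for (int k = 0; k < c; k++) {
                    out[dst + k] = x[src + k];
                }
            }
        }
        return out;
    }

    static std::vector<float> win_unpart_cpu(const std::vector<float> & x,
                                             int c, int w, int h, int win) {
        const int px  = (win - w % win) % win;
        const int npw = (w + px) / win;
        std::vector<float> out((size_t) c * w * h);
        for (int y = 0; y < h; y++) {
            for (int i = 0; i < w; i++) {
                const int wy = y / win, iy = y % win;
                const int wx = i / win, ix = i % win;
                const size_t src = ((((size_t) wy * npw + wx) * win + iy) * win + ix) * c;
                const size_t dst = ((size_t) y * w + i) * c;
                for (int k = 0; k < c; k++) {
                    out[dst + k] = x[src + k]; // padded positions are skipped, i.e. cropped
                }
            }
        }
        return out;
    }

    int main() {
        const int c = 3, w = 30, h = 30, win = 14; // 30 % 14 != 0 exercises the padding
        std::vector<float> x((size_t) c * w * h);
        for (size_t i = 0; i < x.size(); i++) {
            x[i] = (float) i;
        }
        std::vector<float> y = win_unpart_cpu(win_part_cpu(x, c, w, h, win), c, w, h, win);
        assert(x == y); // round trip through padded windows is the identity
        std::printf("round trip OK\n");
        return 0;
    }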