From 386ba479a2c2c24b9ba54d9988cd9ca2986d7d6f Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:05:58 +0100 Subject: [PATCH 1/5] clean up --- convert_hf_to_gguf.py | 50 +------------------- examples/eval-callback/eval-callback.cpp | 18 ++++---- gguf-py/gguf/gguf_writer.py | 4 +- gguf-py/gguf/tensor_mapping.py | 7 +-- tools/mtmd/clip.cpp | 58 +++++++++++------------- 5 files changed, 42 insertions(+), 95 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7ce5816144..8602776bb1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1579,15 +1579,7 @@ class MmprojModel(ModelBase): # TODO @ngxson : this is a hack to support both vision and audio encoders have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder - self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) - # FIXME: DeepseekOCRVisionModel specific hack - if self.block_count is None: - if isinstance(self, DeepseekOCRVisionModel): - clip_block_count = self.hparams['layers'] - if clip_block_count is not None: - self.block_count = clip_block_count - if self.block_count is None: - raise KeyError(f"could not find block count using any of: {self.n_block_keys}") + self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys) self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config @@ -6003,16 +5995,6 @@ class Gemma3VisionModel(MmprojModel): @ModelBase.register("DeepseekOCRForCausalLM") class DeepseekOCRVisionModel(MmprojModel): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - proc_fname = self.dir_model / "processor_config.json" - - if proc_fname.is_file(): - with open(proc_fname, "r") as f: - self.preprocessor_config = json.load(f) - - def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -6071,27 +6053,6 @@ class DeepseekOCRVisionModel(MmprojModel): if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name: return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)] - if name.startswith("model.vision_model.transformer.layers."): - # process visual tensors - # split QKV tensors if needed - if ".qkv_proj." 
in name: - if data_torch.ndim == 2: # weight - c3, _ = data_torch.shape - else: # bias - c3 = data_torch.shape[0] - assert c3 % 3 == 0 - c = c3 // 3 - wq = data_torch[:c] - wk = data_torch[c: c * 2] - wv = data_torch[c * 2:] - return [ - (self.map_tensor_name(name.replace("qkv", "q")), wq), - (self.map_tensor_name(name.replace("qkv", "k")), wk), - (self.map_tensor_name(name.replace("qkv", "v")), wv), - ] - else: - return [(self.map_tensor_name(name), data_torch)] - return [(self.map_tensor_name(name), data_torch)] @@ -7335,10 +7296,9 @@ class DeepseekV2Model(TextModel): super().set_gguf_parameters() hparams = self.hparams - kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 + kv_lora_rank = hparams["kv_lora_rank"] if hparams["kv_lora_rank"] is not None else 512 routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0) norm_topk_prob = hparams.get("norm_topk_prob", False) - scoring_func = hparams.get("scoring_func", "softmax") self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) @@ -7361,12 +7321,6 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) self.gguf_writer.add_expert_weights_norm(norm_topk_prob) - if scoring_func == "sigmoid": - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - elif scoring_func == "softmax": - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) - else: - raise ValueError(f"Unsupported scoring_func value: {scoring_func}") self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) rope_scaling = self.hparams.get("rope_scaling") or {} diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 039bf19c99..80c693ce61 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } } for (int64_t i3 = 0; i3 < ne[3]; i3++) { - LOG(" [\n"); + LOG(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i2 = ne[2] - n; } - LOG(" [\n"); + LOG(" [\n"); for (int64_t i1 = 0; i1 < ne[1]; i1++) { if (i1 == n && ne[1] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i1 = ne[1] - n; } - LOG(" ["); + LOG(" ["); for (int64_t i0 = 0; i0 < ne[0]; i0++) { if (i0 == n && ne[0] > 2*n) { LOG("..., "); @@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } LOG("],\n"); } - LOG(" ],\n"); + LOG(" ],\n"); } - LOG(" ]\n"); - LOG(" sum = %f\n", sum); + LOG(" ]\n"); + LOG(" sum = %f\n", sum); } // TODO: make this abort configurable/optional? @@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } - LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type), ggml_op_desc(t), src0->name, ggml_ne_string(src0).c_str(), src1 ? 
src1_str : "", diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e0de7f8e72..0c04e10c47 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1127,12 +1127,12 @@ class GGUFWriter: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) - def add_vision_sam_layers_count(self, value: int) -> None: self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value) - + def add_vision_sam_embedding_length(self, value: int) -> None: self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value) + # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5032328723..6c339cd7f2 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1,9 +1,6 @@ from __future__ import annotations from typing import Sequence - -from numpy.f2py.auxfuncs import throw_error - from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES @@ -1242,11 +1239,11 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm ), - + MODEL_TENSOR.V_ENC_EMBD_IMGNL: ( "model.image_newline", # Deepseek-OCR ), - + MODEL_TENSOR.V_ENC_EMBD_VSEP: ( "model.view_seperator", # Deepseek-OCR ), diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index a50006ca9d..2e77cc8e97 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -561,9 +561,9 @@ struct clip_graph { hparams(model.hparams), img(img), patch_size(hparams.patch_size), - n_patches_x(img.nx / patch_size), // sam 1024 / 16 = 64 - n_patches_y(img.ny / patch_size), // sam 1024 / 16 = 64 - n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096 + n_patches_x(img.nx / patch_size), + n_patches_y(img.ny / patch_size), + n_patches(n_patches_x * n_patches_y), n_embd(hparams.n_embd), n_head(hparams.n_head), d_head(n_embd / n_head), @@ -664,13 +664,13 @@ struct clip_graph { ggml_tensor * inp_raw = build_inp_raw(); ggml_tensor * sam_out = build_sam(inp_raw); ggml_tensor * clip_out = build_dsocr_clip(sam_out); - + int clip_n_patches = sam_out->ne[0] * sam_out->ne[1]; - + sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3)); sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches); clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]); - + ggml_tensor * cur; cur = ggml_concat(ctx0, clip_out, sam_out, 0); cur = ggml_reshape_2d(ctx0, cur, 2*n_embd,clip_n_patches); @@ -1302,7 +1302,7 @@ struct clip_graph { norm_t, hparams.ffn_op, model.position_embeddings, - nullptr); // shape [1024, 16, 16] + nullptr); // remove CLS token cur = ggml_view_2d(ctx0, cur, @@ -2260,7 +2260,6 @@ private: const int64_t C = rel_pos->ne[0]; // channels const int64_t L = rel_pos->ne[1]; // length - //GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L); const auto max_rel_dist = 2*std::max(q_size, k_size) - 1; ggml_tensor * rel_pos_resized = rel_pos; @@ -2399,18 +2398,15 @@ private: // build the input after conv2d (inp_raw --> patches) // returns tensor with shape [n_embd, n_patches] ggml_tensor * build_inp() { - // Image to Patch Embedding. 
- ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3] - // sam patch_embeddings_0 shape = [768, 3, 16, 16] - ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768] - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768] - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // sam shape = [768, 4096] + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); if (model.patch_bias) { - // sam patch_bias shape = [768] inp = ggml_add(ctx0, inp, model.patch_bias); cb(inp, "patch_bias", -1); } - return inp; // shape = [n_embd, n_patches] same as [768, 4096] + return inp; } ggml_tensor * build_inp_raw(int channels = 3) { @@ -2707,11 +2703,11 @@ private: const int d_heads = n_embd / n_heads; ggml_tensor * inpL; - + inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw); inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd)); inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3)); - + ggml_tensor * cur; const auto tgt_size = inpL->ne[1]; const auto str_size = model.pos_embed->ne[1]; @@ -2756,7 +2752,7 @@ private: // self-attention { const int B = cur->ne[3]; - + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = ggml_add(ctx0, cur, layer.qkv_b); cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape @@ -2836,7 +2832,7 @@ private: cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1); cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - + cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1); cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1); @@ -2866,7 +2862,7 @@ private: if (tgt_size != src_size) { ggml_tensor * old_pos_embd; ggml_tensor * cls_tok; - + old_pos_embd = ggml_view_2d( ctx0, new_pos_embd, new_pos_embd->ne[0], src_size * src_size, @@ -2895,7 +2891,7 @@ private: ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); - ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK, + ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK, learned_pos_embd, nullptr); // shape [1024, 16, 16] ggml_build_forward_expand(gf, cur); @@ -5174,11 +5170,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const int orig_h = original_size.height; const int orig_area = orig_h * orig_w; std::array color; - + for (int i = 0; i < 3; i++) { color[i] = (int)(255 * params.image_mean[i]); } - + int mode_i = 0; int min_diff = orig_area; @@ -5193,7 +5189,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str if (mode_i < 2) { /* Native Resolution (Tiny/Small) */ const int image_size = native_resolutions[mode_i]; - + // Just resize the image to image_size × image_size clip_image_u8_ptr resized_img(clip_image_u8_init()); img_tool::resize(*img, *resized_img, @@ -5210,7 +5206,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str else if (mode_i 
< 4) { /* Native Resolution (Base/Large) */ const int image_size = native_resolutions[mode_i]; - + // Resize maintaining aspect ratio, then pad to square float scale = std::min( static_cast(image_size) / orig_w, @@ -5267,7 +5263,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str else { GGML_ABORT("DeepSeek-OCR hasn't supported Gundam/Gundam-Master yet"); /* Dynamic Resolution (Gundam/Gundam-Master) */ - + // configurable, or read from params const int min_num = 2; const int max_num = 9; @@ -5276,10 +5272,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // original image size const int orig_w = original_size.width; const int orig_h = original_size.height; - + // create overview image (thumbnail) clip_image_u8_ptr overview_img(clip_image_u8_init()); - img_tool::resize(*img, *overview_img, { image_size, image_size }, + img_tool::resize(*img, *overview_img, { image_size, image_size }, img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color); clip_image_f32_ptr overview_f32(clip_image_f32_init()); normalize_image_u8_to_f32(*overview_img, *overview_f32, params.image_mean, params.image_std); @@ -5287,7 +5283,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // build candidate grids (cols, rows) auto target_ratios = ds_build_target_ratios(min_num, max_num); - + // pick the grid that best matches the original aspect ratio const float aspect_ratio = static_cast(orig_w) / static_cast(orig_h); auto best = ds_find_closest_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size); @@ -5296,7 +5292,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // resize to refined size (no padding, direct resize) clip_image_u8_ptr refined_img(clip_image_u8_init()); - img_tool::resize(*img, *refined_img, { image_size * grid_cols, image_size * grid_rows }, + img_tool::resize(*img, *refined_img, { image_size * grid_cols, image_size * grid_rows }, img_tool::RESIZE_ALGO_BICUBIC_PILLOW, false); // crop slices from the refined image From a661c5299091bb2b2e2a984dd13c22072e338902 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:12:41 +0100 Subject: [PATCH 2/5] reverting automatically removed spaces --- tools/mtmd/clip.cpp | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 5bb85a89f1..7771cfc371 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -664,13 +664,13 @@ struct clip_graph { ggml_tensor * inp_raw = build_inp_raw(); ggml_tensor * sam_out = build_sam(inp_raw); ggml_tensor * clip_out = build_dsocr_clip(sam_out); - + int clip_n_patches = sam_out->ne[0] * sam_out->ne[1]; - + sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3)); sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches); clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]); - + ggml_tensor * cur; cur = ggml_concat(ctx0, clip_out, sam_out, 0); cur = ggml_reshape_2d(ctx0, cur, 2*n_embd,clip_n_patches); @@ -2703,11 +2703,11 @@ private: const int d_heads = n_embd / n_heads; ggml_tensor * inpL; - + inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw); inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd)); inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3)); - + ggml_tensor * cur; const auto tgt_size = inpL->ne[1]; 
const auto str_size = model.pos_embed->ne[1]; @@ -2752,7 +2752,7 @@ private: // self-attention { const int B = cur->ne[3]; - + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = ggml_add(ctx0, cur, layer.qkv_b); cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape @@ -2832,7 +2832,7 @@ private: cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1); cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - + cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1); cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1); @@ -2862,7 +2862,7 @@ private: if (tgt_size != src_size) { ggml_tensor * old_pos_embd; ggml_tensor * cls_tok; - + old_pos_embd = ggml_view_2d( ctx0, new_pos_embd, new_pos_embd->ne[0], src_size * src_size, @@ -2891,7 +2891,7 @@ private: ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); - ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK, + ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK, learned_pos_embd, nullptr); // shape [1024, 16, 16] ggml_build_forward_expand(gf, cur); @@ -5167,11 +5167,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const int orig_h = original_size.height; const int orig_area = orig_h * orig_w; std::array color; - + for (int i = 0; i < 3; i++) { color[i] = (int)(255 * params.image_mean[i]); } - + int mode_i = 0; int min_diff = orig_area; @@ -5186,7 +5186,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str if (mode_i < 2) { /* Native Resolution (Tiny/Small) */ const int image_size = native_resolutions[mode_i]; - + // Just resize the image to image_size × image_size clip_image_u8_ptr resized_img(clip_image_u8_init()); img_tool::resize(*img, *resized_img, @@ -5203,7 +5203,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str else if (mode_i < 4) { /* Native Resolution (Base/Large) */ const int image_size = native_resolutions[mode_i]; - + // Resize maintaining aspect ratio, then pad to square float scale = std::min( static_cast(image_size) / orig_w, @@ -5260,7 +5260,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str else { GGML_ABORT("DeepSeek-OCR hasn't supported Gundam/Gundam-Master yet"); /* Dynamic Resolution (Gundam/Gundam-Master) */ - + // configurable, or read from params const int min_num = 2; const int max_num = 9; @@ -5269,10 +5269,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // original image size const int orig_w = original_size.width; const int orig_h = original_size.height; - + // create overview image (thumbnail) clip_image_u8_ptr overview_img(clip_image_u8_init()); - img_tool::resize(*img, *overview_img, { image_size, image_size }, + img_tool::resize(*img, *overview_img, { image_size, image_size }, img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color); clip_image_f32_ptr overview_f32(clip_image_f32_init()); normalize_image_u8_to_f32(*overview_img, *overview_f32, params.image_mean, params.image_std); @@ -5280,7 +5280,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // build candidate grids (cols, rows) auto target_ratios 
= ds_build_target_ratios(min_num, max_num); - + // pick the grid that best matches the original aspect ratio const float aspect_ratio = static_cast(orig_w) / static_cast(orig_h); auto best = ds_find_closest_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size); @@ -5289,7 +5289,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // resize to refined size (no padding, direct resize) clip_image_u8_ptr refined_img(clip_image_u8_init()); - img_tool::resize(*img, *refined_img, { image_size * grid_cols, image_size * grid_rows }, + img_tool::resize(*img, *refined_img, { image_size * grid_cols, image_size * grid_rows }, img_tool::RESIZE_ALGO_BICUBIC_PILLOW, false); // crop slices from the refined image From 0399ddf14583ac02c7e61ed9ae27840ff2738c5a Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:16:59 +0100 Subject: [PATCH 3/5] reverting automatically removed spaces --- gguf-py/gguf/gguf_writer.py | 2 +- gguf-py/gguf/tensor_mapping.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 0c04e10c47..15c318e11c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1129,7 +1129,7 @@ class GGUFWriter: def add_vision_sam_layers_count(self, value: int) -> None: self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value) - + def add_vision_sam_embedding_length(self, value: int) -> None: self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 283d7c59ca..90491b15da 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1240,11 +1240,11 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm ), - + MODEL_TENSOR.V_ENC_EMBD_IMGNL: ( "model.image_newline", # Deepseek-OCR ), - + MODEL_TENSOR.V_ENC_EMBD_VSEP: ( "model.view_seperator", # Deepseek-OCR ), From c89171cf4d729e35da7c39125923165f1dc87071 Mon Sep 17 00:00:00 2001 From: bluebread Date: Thu, 4 Dec 2025 16:50:05 +0000 Subject: [PATCH 4/5] mtmd: fixed bad ocr check in Deepseek2 (LM) --- convert_hf_to_gguf.py | 32 +++++++++++--------------------- gguf-py/gguf/constants.py | 38 ++++++++++++++++++++++++++++++++++++++ src/llama-arch.cpp | 35 +++++++++++++++++++++++++++++++++++ src/llama-arch.h | 1 + src/llama-kv-cache.cpp | 2 +- src/llama-model.cpp | 10 +++++++--- 6 files changed, 93 insertions(+), 25 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7ce5816144..9e77419e8f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6003,15 +6003,6 @@ class Gemma3VisionModel(MmprojModel): @ModelBase.register("DeepseekOCRForCausalLM") class DeepseekOCRVisionModel(MmprojModel): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - proc_fname = self.dir_model / "processor_config.json" - - if proc_fname.is_file(): - with open(proc_fname, "r") as f: - self.preprocessor_config = json.load(f) - def set_gguf_parameters(self): super().set_gguf_parameters() @@ -7263,12 +7254,20 @@ class DeepseekModel(TextModel): @ModelBase.register( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", - "DeepseekOCRForCausalLM", "KimiVLForConditionalGeneration", ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + vision_config = 
self.hparams.get('vision_config', {}).get('width', {}) + + if 'clip-l-14-224' in vision_config and 'sam_vit_b' in vision_config: + self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + def set_vocab(self): try: self._set_vocab_gpt2() @@ -7324,7 +7323,7 @@ class DeepseekV2Model(TextModel): raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): - is_ocr = (self.hparams["num_hidden_layers"] == 12) + is_ocr = (self.model_arch == gguf.MODEL_ARCH.DEEPSEEK2OCR) if is_ocr: self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0) @@ -7335,11 +7334,9 @@ class DeepseekV2Model(TextModel): super().set_gguf_parameters() hparams = self.hparams - kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 + kv_lora_rank = hparams["kv_lora_rank"] if hparams.get("kv_lora_rank") is not None else 512 routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0) norm_topk_prob = hparams.get("norm_topk_prob", False) - scoring_func = hparams.get("scoring_func", "softmax") - self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: @@ -7361,12 +7358,6 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) self.gguf_writer.add_expert_weights_norm(norm_topk_prob) - if scoring_func == "sigmoid": - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - elif scoring_func == "softmax": - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) - else: - raise ValueError(f"Unsupported scoring_func value: {scoring_func}") self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) rope_scaling = self.hparams.get("rope_scaling") or {} @@ -7462,7 +7453,6 @@ class DeepseekV2Model(TextModel): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") - @ModelBase.register("MiniMaxM2ForCausalLM") class MiniMaxM2Model(TextModel): model_arch = gguf.MODEL_ARCH.MINIMAXM2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 884ce68de5..a6f30f67f2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -408,6 +408,7 @@ class MODEL_ARCH(IntEnum): ARCTIC = auto() DEEPSEEK = auto() DEEPSEEK2 = auto() + DEEPSEEK2OCR = auto() CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() @@ -797,6 +798,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK: "deepseek", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", @@ -2375,6 +2377,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.DEEPSEEK2OCR: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + 
MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.ERNIE4_5_MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3192,6 +3226,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK2OCR: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.CHATGLM: [ MODEL_TENSOR.ROPE_FREQS, ], diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7f61d547ee..1cb91209f5 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -66,6 +66,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -1549,6 +1550,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_DEEPSEEK2OCR, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_PLM, { diff --git a/src/llama-arch.h b/src/llama-arch.h index e113180024..f01e7c36b8 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -70,6 +70,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_DEEPSEEK2OCR, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index e26385a1fe..d7a261ba3e 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1385,7 +1385,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. 
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
+    const float yarn_attn_factor = (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_DEEPSEEK2OCR)
         ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
         : cparams.yarn_attn_factor;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 11654cc7a9..7bd75c8043 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1605,10 +1605,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_DEEPSEEK2OCR:
             {
                 // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
                 bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-                bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos);
+                bool is_ocr = (arch == LLM_ARCH_DEEPSEEK2OCR);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
                 if (!is_lite && !is_ocr) {
@@ -4659,10 +4660,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_DEEPSEEK2OCR:
             {
                 // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
                 const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-                const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos);
+                const bool is_ocr = (arch == LLM_ARCH_DEEPSEEK2OCR);

                 const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
@@ -6879,7 +6881,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
     }

-    if (arch == LLM_ARCH_DEEPSEEK2) {
+    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -7406,6 +7408,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                 llm = std::make_unique<llm_build_deepseek>(*this, params);
             } break;
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_DEEPSEEK2OCR:
             {
                 llm = std::make_unique<llm_build_deepseek2>(*this, params);
             } break;
@@ -7754,6 +7757,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK:
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_DEEPSEEK2OCR:
         case LLM_ARCH_PLM:
         case LLM_ARCH_CHATGLM:
         case LLM_ARCH_GLM4:

From fc3f625fefc08303fb6b866bf1e69c709a1b2c79 Mon Sep 17 00:00:00 2001
From: bluebread
Date: Thu, 4 Dec 2025 17:57:43 +0000
Subject: [PATCH 5/5] mtmd: support combined QKV projection in build_vit

---
 tools/mtmd/clip.cpp | 47 ++++++++++++++++++++++++++++++++++-----------
 1 file changed, 36 insertions(+), 11 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 7771cfc371..d1bed23d03 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2152,19 +2152,44 @@ private:
         // self-attention
         {
-            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
-            if (layer.q_b) {
-                Qcur = ggml_add(ctx0, Qcur, layer.q_b);
-            }
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            if (layer.qkv_w) {
+                ggml_tensor * QKV;

-            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
-            if (layer.k_b) {
-                Kcur = ggml_add(ctx0, Kcur, layer.k_b);
-            }
+                QKV = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                if (layer.qkv_b) {
+                    QKV = ggml_add(ctx0, QKV, layer.qkv_b);
+                }
+                QKV = ggml_reshape_4d(ctx0, QKV, cur->ne[0], 3,
cur->ne[1]*cur->ne[2], cur->ne[3]); - ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); - if (layer.v_b) { - Vcur = ggml_add(ctx0, Vcur, layer.v_b); + const int ne0 = QKV->ne[0]; + const int ne2 = QKV->ne[2]; + const int ne3 = QKV->ne[3]; + const int nb1 = QKV->nb[1]; + const int nb2 = QKV->nb[2]; + const int nb3 = QKV->nb[3]; + + Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 0*nb1)); + Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 1*nb1)); + Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 2*nb1)); + } else { + Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + if (layer.q_b) { + Qcur = ggml_add(ctx0, Qcur, layer.q_b); + } + + Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + if (layer.k_b) { + Kcur = ggml_add(ctx0, Kcur, layer.k_b); + } + + Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + if (layer.v_b) { + Vcur = ggml_add(ctx0, Vcur, layer.v_b); + } } if (layer.q_norm) {