From 97e0907c5b6a73d6f3e0614e4bb37e26e42ea17b Mon Sep 17 00:00:00 2001
From: Saba Fallah <10401143+sfallah@users.noreply.github.com>
Date: Mon, 17 Nov 2025 11:07:33 +0100
Subject: [PATCH] deepseek-ocr: LM loading and vision model loading (testing)

---
 convert_hf_to_gguf.py    | 39 +++++++++++++++++++++------------------
 src/llama-arch.cpp       |  2 ++
 src/llama-model.cpp      | 39 +++++++++++++++++++++++++++++++++++++--
 src/models/deepseek2.cpp | 32 +++++++++++++++++++++++++++++++-
 tools/mtmd/clip-impl.h   | 28 ++++++++++++++--------------
 tools/mtmd/clip.cpp      | 19 ++++++++++++++-----
 6 files changed, 119 insertions(+), 40 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 77fc77e823..385864dd11 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1494,12 +1494,9 @@ class MmprojModel(ModelBase):
         # FIXME: DeepseekOCRVisionModel specific hack
         if self.block_count is None:
             if isinstance(self, DeepseekOCRVisionModel):
-                print(self.hparams)
                 clip_block_count = self.hparams['layers']
                 if clip_block_count is not None:
                     self.block_count = clip_block_count
-                if sam_block_count is not None:
-                    self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count
         if self.block_count is None:
             raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
@@ -5793,16 +5790,16 @@ class Gemma3VisionModel(MmprojModel):
 
 @ModelBase.register("DeepseekOCRForCausalLM")
 class DeepseekOCRVisionModel(MmprojModel):
-    def __init__(self, *args, **kwargs): 
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-    
+
         proc_fname = self.dir_model / "processor_config.json"
-    
+
         if proc_fname.is_file():
             with open(proc_fname, "r") as f:
                 self.preprocessor_config = json.load(f)
-    
-    
+
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -5860,7 +5857,7 @@ class DeepseekOCRVisionModel(MmprojModel):
             return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]
 
         return [(self.map_tensor_name(name), data_torch)]
-    
+
 
 @ModelBase.register("Gemma3nForConditionalGeneration")
 class Gemma3NModel(Gemma3Model):
@@ -7095,9 +7092,14 @@ class DeepseekV2Model(TextModel):
             raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
 
     def set_gguf_parameters(self):
+        is_ocr = (self.hparams["num_hidden_layers"] == 12)
 
-        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
-        self.hparams["num_key_value_heads"] = 1
+        if is_ocr:
+            self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
+            self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6)
+        else:
+            # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+            self.hparams["num_key_value_heads"] = 1
 
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -7110,13 +7112,16 @@ class DeepseekV2Model(TextModel):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
-        self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
+        if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None:
+            self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
 
         # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-        self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(kv_lora_rank)
-        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+        if not is_ocr:
+            self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length(kv_lora_rank)
+            self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+
+        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
@@ -7131,8 +7136,6 @@ class DeepseekV2Model(TextModel):
         else:
             raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
 
-        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
-
         rope_scaling = self.hparams.get("rope_scaling") or {}
         if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
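
Note on the conversion branch above: the converter treats any 12-layer DeepSeek2
checkpoint as the DeepSeek-OCR decoder and keeps plain multi-head attention,
while all other depths take the usual MLA-to-MQA export. A minimal standalone
sketch of that decision, with a hypothetical hparams dict standing in for the
real config.json:

    # sketch only; the values below are illustrative, not read from a checkpoint
    hparams = {"num_hidden_layers": 12, "qk_rope_head_dim": 64}

    is_ocr = hparams["num_hidden_layers"] == 12   # assumed OCR decoder marker

    if is_ocr:
        # plain MHA decoder: fill defaults, skip the MLA-only GGUF keys
        hparams.setdefault("rope_theta", 10000.0)
        hparams.setdefault("rms_norm_eps", 1e-6)
    else:
        # MLA is exported as MQA (GQA with a single KV group)
        hparams["num_key_value_heads"] = 1

    # kv_lora_rank is now written only when the checkpoint actually has one
    if hparams.get("kv_lora_rank") is not None:
        print("would call gguf_writer.add_kv_lora_rank")

Keying the branch on an exact layer count is brittle; if another 12-layer
DeepSeek2 variant appears, a dedicated hparam or architecture string would be a
safer discriminator.
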
hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + if not is_ocr: + self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(kv_lora_rank) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) @@ -7131,8 +7136,6 @@ class DeepseekV2Model(TextModel): else: raise ValueError(f"Unsupported scoring_func value: {scoring_func}") - self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568d..ac3ab5cfa7 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1446,6 +1446,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14..79639c515e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1562,12 +1562,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - if (!is_lite) { + if (!is_lite && !is_ocr) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + if (!is_ocr) { + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + } ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -1583,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; @@ -4550,6 +4555,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { const bool is_lite = (hparams.n_layer == 27); + const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -4575,6 +4581,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + if (is_ocr) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), 
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index 68f72f72bb..375f359454 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -5,6 +5,7 @@
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
         llm_graph_context(params) {
     bool is_lite = (hparams.n_layer == 27);
+    bool is_ocr  = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos);
 
     const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
 
@@ -44,7 +45,36 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
         cb(cur, "attn_norm", il);
 
         // self_attention
-        {
+        if (is_ocr) {
+            const int n_embed_head = hparams.n_embd / hparams.n_head();
+            GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v);
+
+            ggml_tensor * Qcur = NULL;
+            ggml_tensor * Kcur = NULL;
+            ggml_tensor * Vcur = NULL;
+
+            Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+            Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+            cb(Qcur, "q", il);
+            cb(Kcur, "k", il);
+            cb(Vcur, "v", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens);
+
+            GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4);
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0);
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0);
+            cb(Qcur, "q_pe", il);
+            cb(Kcur, "k_pe", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+        }
+        else {
             ggml_tensor * q = NULL;
             if (!is_lite) {
                 q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
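
The OCR branch above is standard multi-head attention: one [n_embd, n_embd]
projection each for Q, K and V, a reshape to per-head vectors, RoPE over the
full head dimension (freq_base pinned to 10000 by the assert), then scaled
dot-product attention through build_attn. A numpy sketch of the shape flow,
with illustrative sizes and RoPE omitted for brevity:

    import numpy as np

    n_embd, n_head, n_tokens = 1280, 10, 7      # illustrative only
    n_embd_head = n_embd // n_head              # full per-head width

    x  = np.random.randn(n_tokens, n_embd)
    wq = np.random.randn(n_embd, n_embd)        # wk, wv, wo share this shape
    q  = (x @ wq).reshape(n_tokens, n_head, n_embd_head)
    k  = (x @ wq).reshape(n_tokens, n_head, n_embd_head)   # wq reused for brevity
    v  = (x @ wq).reshape(n_tokens, n_head, n_embd_head)

    # per-head scaled dot-product attention (causal mask omitted)
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(n_embd_head)
    probs  = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    out    = np.einsum("hqk,khd->qhd", probs, v).reshape(n_tokens, n_embd)

Unlike the MLA path in the else branch, where only the qk_rope_head_dim slice
of each head is rotated, this branch rotates the whole head; the cb names
"q_pe"/"k_pe" are inherited from the MLA code even though they now label fully
rotated tensors.
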
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 4cb2808c26..520e0cf508 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -130,18 +130,18 @@
 #define TN_TOK_EOI          "v.eoi"
 
 // deepseek-ocr
-#define TN_SAM_POS_EMBD     "sam.pos_embd"
-#define TN_SAM_PATCH_EMBD   "sam.patch_embd.%s"
-#define TN_SAM_PRE_NORM     "sam.blk.%d.pre_ln.%s"
-#define TN_SAM_POST_NORM    "sam.blk.%d.post_ln"
-#define TN_SAM_ATTN_POS_H   "sam.blk.%d.attn.pos_h"
-#define TN_SAM_ATTN_POS_W   "sam.blk.%d.attn.pos_w"
-#define TN_SAM_ATTN_QKV     "sam.blk.%d.attn.qkv.%s"
-#define TN_SAM_ATTN_OUT     "sam.blk.%d.attn.out.%s"
-#define TN_SAM_FFN_UP       "sam.blk.%d.mlp.lin1.%s"
-#define TN_SAM_FFN_DOWN     "sam.blk.%d.mlp.lin2.%s"
-#define TN_SAM_NECK         "sam.neck.%d.%s"
-#define TN_SAM_NET          "sam.net_%d.%s"
+#define TN_SAM_POS_EMBD     "v.sam.pos_embd"
+#define TN_SAM_PATCH_EMBD   "v.sam.patch_embd.%s"
+#define TN_SAM_PRE_NORM     "v.sam.blk.%d.pre_ln.%s"
+#define TN_SAM_POST_NORM    "v.sam.blk.%d.post_ln"
+#define TN_SAM_ATTN_POS_H   "v.sam.blk.%d.attn.pos_h"
+#define TN_SAM_ATTN_POS_W   "v.sam.blk.%d.attn.pos_w"
+#define TN_SAM_ATTN_QKV     "v.sam.blk.%d.attn.qkv.%s"
+#define TN_SAM_ATTN_OUT     "v.sam.blk.%d.attn.out.%s"
+#define TN_SAM_FFN_UP       "v.sam.blk.%d.mlp.lin1.%s"
+#define TN_SAM_FFN_DOWN     "v.sam.blk.%d.mlp.lin2.%s"
+#define TN_SAM_NECK         "v.sam.neck.%d.%s"
+#define TN_SAM_NET          "v.sam.net_%d.%s"
 
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
 
@@ -170,7 +170,7 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
-    PROJECTOR_TYPE_DEEPSEEK_OCR,
+    PROJECTOR_TYPE_DEEPSEEKOCR,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -197,7 +197,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
    { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
-    { PROJECTOR_TYPE_DEEPSEEK_OCR,"deepseek_orc"},
+    { PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
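
Prefixing the SAM tensor names with "v." keeps the encoder in the same v.*
namespace as the rest of the vision tower, and the map entry fixes the old
"deepseek_orc" typo so clip_projector_type_from_string can actually resolve the
projector string. An illustrative expansion of the printf-style templates:

    TN_SAM_ATTN_QKV = "v.sam.blk.%d.attn.qkv.%s"   # template from the header
    TN_SAM_NECK     = "v.sam.neck.%d.%s"

    print(TN_SAM_ATTN_QKV % (0, "weight"))   # -> v.sam.blk.0.attn.qkv.weight
    print(TN_SAM_NECK % (2, "bias"))         # -> v.sam.neck.2.bias

The rename to PROJECTOR_TYPE_DEEPSEEKOCR also matches the single-word style of
the neighbouring OCR projectors (e.g. PROJECTOR_TYPE_LIGHTONOCR).
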
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index d94d05b2f2..5d4257ac84 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -682,8 +682,8 @@ struct clip_graph {
 
         const int enc_n_patches = enc_image_size / enc_patch_size; // 64
 
-        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_image_size, enc_n_embd);
-        ggml_tensor * cur = ggml_add(ctx0, inpL, model.position_embeddings);
+        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
+        ggml_tensor * cur = ggml_add(ctx0, inpL, model.pos_embed);
 
         // loop over layers
         for (int il = 0; il < _depth; il++) {
@@ -842,7 +842,7 @@ struct clip_graph {
 
         ggml_tensor * inp_raw = build_inp_raw();
 
-        ggml_tensor * global_features_1 = build_sam_enc(inp_raw);
+        ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));
 
         ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1);
 
@@ -2862,6 +2862,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_cogvlm();
             } break;
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
+            {
+                res = graph.build_deepseek_ocr();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -3187,6 +3191,11 @@ struct clip_model_loader {
                     hparams.ffn_op = FFN_GELU_ERF;
                     log_ffn_op = "gelu_erf"; // temporary solution for logging
                 } break;
+            case PROJECTOR_TYPE_DEEPSEEKOCR:
+                {
+                    hparams.set_limit_image_tokens(8, 1024);
+                    hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
+                } break;
             default:
                 break;
         }
@@ -3574,7 +3583,7 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
                 } break;
-            case PROJECTOR_TYPE_DEEPSEEK_OCR:
+            case PROJECTOR_TYPE_DEEPSEEKOCR:
                 {
                     model.pos_embed = get_tensor(TN_SAM_POS_EMBD);
                     model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight"));
@@ -4830,7 +4839,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 }
             }
         } break;
-        case PROJECTOR_TYPE_DEEPSEEK_OCR:
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
         {
            // configurable, or read from params
            const int min_num = 2;