mtmd: fixed bad ocr check in Deepseek2 (LM)

This commit is contained in:
bluebread 2025-12-04 16:50:05 +00:00
parent 7451b84105
commit c89171cf4d
6 changed files with 93 additions and 25 deletions

View File

@ -6003,15 +6003,6 @@ class Gemma3VisionModel(MmprojModel):
@ModelBase.register("DeepseekOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
proc_fname = self.dir_model / "processor_config.json"
if proc_fname.is_file():
with open(proc_fname, "r") as f:
self.preprocessor_config = json.load(f)
def set_gguf_parameters(self):
super().set_gguf_parameters()
@ -7263,12 +7254,20 @@ class DeepseekModel(TextModel):
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"KimiVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
vision_config = self.hparams.get('vision_config', {}).get('width', {})
if 'clip-l-14-224' in vision_config and 'sam_vit_b' in vision_config:
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
def set_vocab(self):
try:
self._set_vocab_gpt2()
@ -7324,7 +7323,7 @@ class DeepseekV2Model(TextModel):
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
def set_gguf_parameters(self):
is_ocr = (self.hparams["num_hidden_layers"] == 12)
is_ocr = (self.model_arch == gguf.MODEL_ARCH.DEEPSEEK2OCR)
if is_ocr:
self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
@ -7335,11 +7334,9 @@ class DeepseekV2Model(TextModel):
super().set_gguf_parameters()
hparams = self.hparams
kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512
kv_lora_rank = hparams["kv_lora_rank"] if hparams.get("kv_lora_rank") is not None else 512
routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
norm_topk_prob = hparams.get("norm_topk_prob", False)
scoring_func = hparams.get("scoring_func", "softmax")
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
@ -7361,12 +7358,6 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
if scoring_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif scoring_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
rope_scaling = self.hparams.get("rope_scaling") or {}
@ -7462,7 +7453,6 @@ class DeepseekV2Model(TextModel):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MINIMAXM2

View File

@ -408,6 +408,7 @@ class MODEL_ARCH(IntEnum):
ARCTIC = auto()
DEEPSEEK = auto()
DEEPSEEK2 = auto()
DEEPSEEK2OCR = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
@ -797,6 +798,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
@ -2375,6 +2377,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.DEEPSEEK2OCR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.ERNIE4_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -3192,6 +3226,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.DEEPSEEK2OCR: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],

View File

@ -66,6 +66,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
@ -1549,6 +1550,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_DEEPSEEK2OCR,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_PLM,
{

View File

@ -70,6 +70,7 @@ enum llm_arch {
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_DEEPSEEK2OCR,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,

View File

@ -1385,7 +1385,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
const float yarn_attn_factor = (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_DEEPSEEK2OCR)
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
: cparams.yarn_attn_factor;

View File

@ -1605,10 +1605,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
{
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos);
bool is_ocr = (arch == LLM_ARCH_DEEPSEEK2OCR);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
if (!is_lite && !is_ocr) {
@ -4659,10 +4660,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
{
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos);
const bool is_ocr = (arch == LLM_ARCH_DEEPSEEK2OCR);
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
@ -6879,7 +6881,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
}
if (arch == LLM_ARCH_DEEPSEEK2) {
if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@ -7406,6 +7408,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique<llm_build_deepseek>(*this, params);
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
{
llm = std::make_unique<llm_build_deepseek2>(*this, params);
} break;
@ -7754,6 +7757,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GLM4: