diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 77fc77e823..c8a48c01bf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1494,12 +1494,9 @@ class MmprojModel(ModelBase): # FIXME: DeepseekOCRVisionModel specific hack if self.block_count is None: if isinstance(self, DeepseekOCRVisionModel): - print(self.hparams) clip_block_count = self.hparams['layers'] if clip_block_count is not None: self.block_count = clip_block_count - if sam_block_count is not None: - self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count if self.block_count is None: raise KeyError(f"could not find block count using any of: {self.n_block_keys}") self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) @@ -7095,10 +7092,15 @@ class DeepseekV2Model(TextModel): raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): + is_ocr = (self.hparams["num_hidden_layers"] == 12) - # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) - self.hparams["num_key_value_heads"] = 1 - + if is_ocr: + self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0) + self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6) + else: + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + super().set_gguf_parameters() hparams = self.hparams kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index fca498a859..34ecb5e396 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -813,7 +813,7 @@ class GGUFWriter: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) def add_layer_norm_rms_eps(self, value: float) -> None: - self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + self.add_float64(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) def add_group_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568d..ac3ab5cfa7 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1446,6 +1446,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14..a21a3ce619 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4550,6 +4550,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { const bool is_lite = (hparams.n_layer == 27); + const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -4575,6 +4576,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + if (is_ocr) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + continue; + } + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); if (!is_lite) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb..e649286cec 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -5,6 +5,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -44,7 +45,33 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head, n_tokens); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + else { ggml_tensor * q = NULL; if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);