This commit is contained in:
Saba Fallah 2025-12-04 15:05:58 +01:00
parent b26b507c4e
commit 386ba479a2
5 changed files with 42 additions and 95 deletions

View File

@ -1579,15 +1579,7 @@ class MmprojModel(ModelBase):
# TODO @ngxson : this is a hack to support both vision and audio encoders # TODO @ngxson : this is a hack to support both vision and audio encoders
have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys)
# FIXME: DeepseekOCRVisionModel specific hack
if self.block_count is None:
if isinstance(self, DeepseekOCRVisionModel):
clip_block_count = self.hparams['layers']
if clip_block_count is not None:
self.block_count = clip_block_count
if self.block_count is None:
raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config # load preprocessor config
@ -6003,16 +5995,6 @@ class Gemma3VisionModel(MmprojModel):
@ModelBase.register("DeepseekOCRForCausalLM") @ModelBase.register("DeepseekOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel): class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
proc_fname = self.dir_model / "processor_config.json"
if proc_fname.is_file():
with open(proc_fname, "r") as f:
self.preprocessor_config = json.load(f)
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
hparams = self.hparams hparams = self.hparams
@ -6071,27 +6053,6 @@ class DeepseekOCRVisionModel(MmprojModel):
if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name: if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name:
return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)] return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]
if name.startswith("model.vision_model.transformer.layers."):
# process visual tensors
# split QKV tensors if needed
if ".qkv_proj." in name:
if data_torch.ndim == 2: # weight
c3, _ = data_torch.shape
else: # bias
c3 = data_torch.shape[0]
assert c3 % 3 == 0
c = c3 // 3
wq = data_torch[:c]
wk = data_torch[c: c * 2]
wv = data_torch[c * 2:]
return [
(self.map_tensor_name(name.replace("qkv", "q")), wq),
(self.map_tensor_name(name.replace("qkv", "k")), wk),
(self.map_tensor_name(name.replace("qkv", "v")), wv),
]
else:
return [(self.map_tensor_name(name), data_torch)]
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@ -7335,10 +7296,9 @@ class DeepseekV2Model(TextModel):
super().set_gguf_parameters() super().set_gguf_parameters()
hparams = self.hparams hparams = self.hparams
kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 kv_lora_rank = hparams["kv_lora_rank"] if hparams["kv_lora_rank"] is not None else 512
routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0) routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
norm_topk_prob = hparams.get("norm_topk_prob", False) norm_topk_prob = hparams.get("norm_topk_prob", False)
scoring_func = hparams.get("scoring_func", "softmax")
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"])
@ -7361,12 +7321,6 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
self.gguf_writer.add_expert_weights_norm(norm_topk_prob) self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
if scoring_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif scoring_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
rope_scaling = self.hparams.get("rope_scaling") or {} rope_scaling = self.hparams.get("rope_scaling") or {}

View File

@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
} }
} }
for (int64_t i3 = 0; i3 < ne[3]; i3++) { for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG(" [\n"); LOG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) { for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) { if (i2 == n && ne[2] > 2*n) {
LOG(" ..., \n"); LOG(" ..., \n");
i2 = ne[2] - n; i2 = ne[2] - n;
} }
LOG(" [\n"); LOG(" [\n");
for (int64_t i1 = 0; i1 < ne[1]; i1++) { for (int64_t i1 = 0; i1 < ne[1]; i1++) {
if (i1 == n && ne[1] > 2*n) { if (i1 == n && ne[1] > 2*n) {
LOG(" ..., \n"); LOG(" ..., \n");
i1 = ne[1] - n; i1 = ne[1] - n;
} }
LOG(" ["); LOG(" [");
for (int64_t i0 = 0; i0 < ne[0]; i0++) { for (int64_t i0 = 0; i0 < ne[0]; i0++) {
if (i0 == n && ne[0] > 2*n) { if (i0 == n && ne[0] > 2*n) {
LOG("..., "); LOG("..., ");
@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
} }
LOG("],\n"); LOG("],\n");
} }
LOG(" ],\n"); LOG(" ],\n");
} }
LOG(" ]\n"); LOG(" ]\n");
LOG(" sum = %f\n", sum); LOG(" sum = %f\n", sum);
} }
// TODO: make this abort configurable/optional? // TODO: make this abort configurable/optional?
@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
} }
LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
t->name, ggml_type_name(t->type), ggml_op_desc(t), t->name, ggml_type_name(t->type), ggml_op_desc(t),
src0->name, ggml_ne_string(src0).c_str(), src0->name, ggml_ne_string(src0).c_str(),
src1 ? src1_str : "", src1 ? src1_str : "",

View File

@ -1127,12 +1127,12 @@ class GGUFWriter:
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
def add_vision_sam_layers_count(self, value: int) -> None: def add_vision_sam_layers_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value) self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value)
def add_vision_sam_embedding_length(self, value: int) -> None: def add_vision_sam_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value) self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value)
# audio models # audio models
def add_audio_projection_dim(self, value: int) -> None: def add_audio_projection_dim(self, value: int) -> None:

View File

@ -1,9 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import Sequence
from numpy.f2py.auxfuncs import throw_error
from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES

View File

@ -561,9 +561,9 @@ struct clip_graph {
hparams(model.hparams), hparams(model.hparams),
img(img), img(img),
patch_size(hparams.patch_size), patch_size(hparams.patch_size),
n_patches_x(img.nx / patch_size), // sam 1024 / 16 = 64 n_patches_x(img.nx / patch_size),
n_patches_y(img.ny / patch_size), // sam 1024 / 16 = 64 n_patches_y(img.ny / patch_size),
n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096 n_patches(n_patches_x * n_patches_y),
n_embd(hparams.n_embd), n_embd(hparams.n_embd),
n_head(hparams.n_head), n_head(hparams.n_head),
d_head(n_embd / n_head), d_head(n_embd / n_head),
@ -1302,7 +1302,7 @@ struct clip_graph {
norm_t, norm_t,
hparams.ffn_op, hparams.ffn_op,
model.position_embeddings, model.position_embeddings,
nullptr); // shape [1024, 16, 16] nullptr);
// remove CLS token // remove CLS token
cur = ggml_view_2d(ctx0, cur, cur = ggml_view_2d(ctx0, cur,
@ -2260,7 +2260,6 @@ private:
const int64_t C = rel_pos->ne[0]; // channels const int64_t C = rel_pos->ne[0]; // channels
const int64_t L = rel_pos->ne[1]; // length const int64_t L = rel_pos->ne[1]; // length
//GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
const auto max_rel_dist = 2*std::max(q_size, k_size) - 1; const auto max_rel_dist = 2*std::max(q_size, k_size) - 1;
ggml_tensor * rel_pos_resized = rel_pos; ggml_tensor * rel_pos_resized = rel_pos;
@ -2399,18 +2398,15 @@ private:
// build the input after conv2d (inp_raw --> patches) // build the input after conv2d (inp_raw --> patches)
// returns tensor with shape [n_embd, n_patches] // returns tensor with shape [n_embd, n_patches]
ggml_tensor * build_inp() { ggml_tensor * build_inp() {
// Image to Patch Embedding. ggml_tensor * inp_raw = build_inp_raw();
ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3] ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
// sam patch_embeddings_0 shape = [768, 3, 16, 16] inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768] inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768]
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // sam shape = [768, 4096]
if (model.patch_bias) { if (model.patch_bias) {
// sam patch_bias shape = [768]
inp = ggml_add(ctx0, inp, model.patch_bias); inp = ggml_add(ctx0, inp, model.patch_bias);
cb(inp, "patch_bias", -1); cb(inp, "patch_bias", -1);
} }
return inp; // shape = [n_embd, n_patches] same as [768, 4096] return inp;
} }
ggml_tensor * build_inp_raw(int channels = 3) { ggml_tensor * build_inp_raw(int channels = 3) {