parent 7f3e9d339c
commit 43a130b4d0

@@ -620,6 +620,9 @@ class ModelBase:
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
+        if "language_config" in config:
+            # rename for DeepSeekOCR
+            config["text_config"] = config["language_config"]
         return config

     @classmethod
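
Reviewer note: the hunk above mirrors the existing Qwen2.5-Omni special case; DeepSeek-OCR nests its text-model hyperparameters under "language_config", so the loader simply aliases that sub-dict to "text_config". A minimal sketch of the behaviour (the toy config values below are made up):

    # toy config; values are placeholders, not real checkpoint contents
    config = {"language_config": {"hidden_size": 1280, "num_hidden_layers": 12}}
    if "language_config" in config:
        config["text_config"] = config["language_config"]
    assert config["text_config"]["hidden_size"] == 1280
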
@@ -1442,7 +1445,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]

-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "width.clip-l-14-224.layers", "sam_vit_b.layers"]

     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -1488,13 +1491,31 @@ class MmprojModel(ModelBase):
         # TODO @ngxson : this is a hack to support both vision and audio encoders
         have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
         self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
+        # FIXME: DeepseekOCRVisionModel specific hack
+        if self.block_count is None:
+            if isinstance(self, DeepseekOCRVisionModel):
+                clip_block_count = self.hparams['width']['clip-l-14-224']['layers']
+                sam_block_count = self.hparams['width']['sam_vit_b']['layers']
+                if clip_block_count is not None:
+                    self.block_count = clip_block_count
+                if sam_block_count is not None:
+                    self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count
+        if self.block_count is None:
+            raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

         # load preprocessor config
         self.preprocessor_config = {}
         if not self.is_mistral_format:
-            with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
-                self.preprocessor_config = json.load(f)
+            # check if preprocessor_config.json exists
+            if (self.dir_model / "preprocessor_config.json").is_file():
+                with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+                    self.preprocessor_config = json.load(f)
+            else:
+                # fall back to "processing_config.json" if it exists
+                if (self.dir_model / "processing_config.json").is_file():
+                    with open(self.dir_model / "processing_config.json", "r", encoding="utf-8") as f:
+                        self.preprocessor_config = json.load(f)

     def get_vision_config(self) -> dict[str, Any] | None:
         config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
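
Reviewer note: the FIXME branch above sums the depths of the two vision towers, since the MMPROJ tensor name map must cover the blocks of both the CLIP encoder and the SAM encoder; the dotted entries added to n_block_keys may not resolve through find_hparam, which is presumably why the explicit nested lookup exists. A standalone sketch, assuming a DeepSeek-OCR style nested "width" config (the layer counts follow CLIP-L = 24 and SAM ViT-B = 12):

    # sketch: block-count fallback over a nested config (assumed layout)
    hparams = {"width": {"clip-l-14-224": {"layers": 24}, "sam_vit_b": {"layers": 12}}}
    block_count = None
    clip_layers = hparams["width"]["clip-l-14-224"]["layers"]
    sam_layers = hparams["width"]["sam_vit_b"]["layers"]
    if clip_layers is not None:
        block_count = clip_layers
    if sam_layers is not None:
        block_count = sam_layers if block_count is None else block_count + sam_layers
    assert block_count == 36  # the name map must cover both encoders' blocks
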
@@ -5770,6 +5791,61 @@ class Gemma3VisionModel(MmprojModel):

         return []  # skip other tensors


+@ModelBase.register("DeepseekOCRForCausalLM")
+class DeepseekOCRVisionModel(MmprojModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR)
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
+        # calculate proj_scale_factor
+        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
+        n_per_side = int(image_seq_length ** 0.5)
+        image_size = self.hparams["image_size"]
+        patch_size = self.hparams["patch_size"]
+        proj_scale_factor = (image_size // patch_size) // n_per_side
+        if proj_scale_factor > 0 and proj_scale_factor != 4:
+            # we only need to write this if it's not the default value
+            self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return super().get_vision_config()
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        # related to https://github.com/ggml-org/llama.cpp/issues/13025
+        if "input_projection" in name:
+            return gguf.GGMLQuantizationType.F16
+        if ".embeddings." in name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if "vision_model.head." in name:
+            return []  # skip redundant head tensors
+
+        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
+            # process vision tensors
+            name = name.replace("_weight", ".weight")
+
+            # correct norm value; only "soft_emb_norm" needs to be corrected as it's part of the Gemma projector
+            # the other norm values are part of the SigLIP model, and they are already correct
+            # ref code: Gemma3RMSNorm
+            if "soft_emb_norm.weight" in name:
+                logger.info(f"Correcting norm value for '{name}'")
+                data_torch = data_torch + 1
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []  # skip other tensors
+
+
 @ModelBase.register("Gemma3nForConditionalGeneration")
 class Gemma3NModel(Gemma3Model):
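
Reviewer note: the proj_scale_factor computation above only writes a value when it differs from the default of 4. A worked sketch of the arithmetic, with assumed hyperparameter values:

    # sketch: proj_scale_factor arithmetic (values are assumptions)
    image_seq_length = 256                       # tokens per image
    n_per_side = int(image_seq_length ** 0.5)    # 16
    image_size, patch_size = 1024, 16            # assumed encoder config
    proj_scale_factor = (image_size // patch_size) // n_per_side  # 64 // 16
    assert proj_scale_factor == 4                # default: nothing written to GGUF
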
@@ -6943,6 +7019,7 @@ class DeepseekModel(TextModel):
 @ModelBase.register(
     "DeepseekV2ForCausalLM",
     "DeepseekV3ForCausalLM",
+    "DeepseekOCRForCausalLM",
     "KimiVLForConditionalGeneration",
 )
 class DeepseekV2Model(TextModel):
@ -7009,31 +7086,35 @@ class DeepseekV2Model(TextModel):
|
||||||
|
|
||||||
super().set_gguf_parameters()
|
super().set_gguf_parameters()
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
|
kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512
|
||||||
|
routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
|
||||||
|
norm_topk_prob = hparams.get("norm_topk_prob", False)
|
||||||
|
scoring_func = hparams.get("scoring_func", "softmax")
|
||||||
|
|
||||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||||
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
|
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
|
||||||
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
||||||
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
|
self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
|
||||||
|
|
||||||
# note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
# note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
||||||
self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
|
self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
|
||||||
self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
|
self.gguf_writer.add_value_length(kv_lora_rank)
|
||||||
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||||
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
|
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
|
||||||
|
|
||||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||||
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
||||||
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
||||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
|
||||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
|
||||||
|
|
||||||
if hparams["scoring_func"] == "sigmoid":
|
if scoring_func == "sigmoid":
|
||||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||||
elif hparams["scoring_func"] == "softmax":
|
elif scoring_func == "softmax":
|
||||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
|
raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
|
||||||
|
|
||||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||||
|
|
||||||
|
|
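
Reviewer note: the new locals above make the MoE/MLA hparams optional, with defaults for checkpoints (like DeepSeek-OCR) that omit them; the original line read kv_lora_rank out of "q_lora_rank", which looks like a typo and is fixed above. A worked sketch of what the MLA lengths become with DeepSeek-V2-style values (assumed here):

    # sketch: MLA key/value lengths written to GGUF (assumed DeepSeek-V2-like hparams)
    kv_lora_rank = 512
    qk_rope_head_dim = 64
    qk_nope_head_dim = 128
    v_head_dim = 128
    key_length = kv_lora_rank + qk_rope_head_dim          # compressed KV + RoPE part
    value_length = kv_lora_rank
    key_length_mla = qk_nope_head_dim + qk_rope_head_dim  # after decompression to MHA
    value_length_mla = v_head_dim
    assert (key_length, value_length) == (576, 512)
    assert (key_length_mla, value_length_mla) == (192, 128)
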
@ -7043,12 +7124,14 @@ class DeepseekV2Model(TextModel):
|
||||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
|
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
|
||||||
|
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
||||||
|
|
||||||
_experts: list[dict[str, Tensor]] | None = None
|
_experts: list[dict[str, Tensor]] | None = None
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
# skip vision tensors and remove "language_model." for Kimi-VL
|
# skip vision tensors and remove "language_model." for Kimi-VL
|
||||||
if "vision_tower" in name or "multi_modal_projector" in name:
|
if "vision_" in name or "multi_modal_projector" in name \
|
||||||
|
or "image_newline" in name or "model.projector" in name or "sam_model" in name or "view_seperator" in name:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if name.startswith("language_model."):
|
if name.startswith("language_model."):
|
||||||
|
|
|
||||||
|
|
@ -664,6 +664,21 @@ class MODEL_TENSOR(IntEnum):
|
||||||
V_MM_GATE = auto() # cogvlm
|
V_MM_GATE = auto() # cogvlm
|
||||||
V_TOK_BOI = auto() # cogvlm
|
V_TOK_BOI = auto() # cogvlm
|
||||||
V_TOK_EOI = auto() # cogvlm
|
V_TOK_EOI = auto() # cogvlm
|
||||||
|
# DeepSeek-OCR sam_model
|
||||||
|
V_SAM_POS_EMBD = auto()
|
||||||
|
V_SAM_PATCH_EMBD = auto()
|
||||||
|
V_SAM_PRE_NORM = auto()
|
||||||
|
V_SAM_POST_NORM = auto()
|
||||||
|
V_SAM_ATTN_POS_H = auto()
|
||||||
|
V_SAM_ATTN_POS_W = auto()
|
||||||
|
V_SAM_ATTN_QKV = auto()
|
||||||
|
V_SAM_ATTN_OUT = auto()
|
||||||
|
V_SAM_MLP_LIN_1 = auto()
|
||||||
|
V_SAM_MLP_LIN_2 = auto()
|
||||||
|
V_SAM_NECK = auto()
|
||||||
|
V_SAM_NET_2 = auto()
|
||||||
|
V_SAM_NET_3 = auto()
|
||||||
|
|
||||||
# audio (mtmd)
|
# audio (mtmd)
|
||||||
A_ENC_EMBD_POS = auto()
|
A_ENC_EMBD_POS = auto()
|
||||||
A_ENC_CONV1D = auto()
|
A_ENC_CONV1D = auto()
|
||||||
|
|
@ -1030,6 +1045,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
MODEL_TENSOR.V_MM_GATE: "mm.gate",
|
MODEL_TENSOR.V_MM_GATE: "mm.gate",
|
||||||
MODEL_TENSOR.V_TOK_BOI: "v.boi",
|
MODEL_TENSOR.V_TOK_BOI: "v.boi",
|
||||||
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
|
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
|
||||||
|
# DeepSeek-OCR sam_model
|
||||||
|
MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd",
|
||||||
|
MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd",
|
||||||
|
MODEL_TENSOR.V_SAM_PRE_NORM: "v.sam.blk.{bid}.pre_ln",
|
||||||
|
MODEL_TENSOR.V_SAM_POST_NORM: "v.sam.blk.{bid}.post_ln",
|
||||||
|
MODEL_TENSOR.V_SAM_ATTN_POS_H: "v.sam.blk.{bid}.attn.pos_h",
|
||||||
|
MODEL_TENSOR.V_SAM_ATTN_POS_W: "v.sam.blk.{bid}.attn.pos_w",
|
||||||
|
MODEL_TENSOR.V_SAM_ATTN_QKV: "v.sam.blk.{bid}.attn.qkv",
|
||||||
|
MODEL_TENSOR.V_SAM_ATTN_OUT: "v.sam.blk.{bid}.attn.out",
|
||||||
|
MODEL_TENSOR.V_SAM_MLP_LIN_1: "v.sam.blk.{bid}.mlp.lin1",
|
||||||
|
MODEL_TENSOR.V_SAM_MLP_LIN_2: "v.sam.blk.{bid}.mlp.lin2",
|
||||||
|
MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}",
|
||||||
|
MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2",
|
||||||
|
MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3",
|
||||||
# audio (mtmd)
|
# audio (mtmd)
|
||||||
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
|
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
|
||||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||||
|
|
@ -2247,7 +2276,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.ATTN_Q_B,
|
MODEL_TENSOR.ATTN_Q_B,
|
||||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||||
MODEL_TENSOR.ATTN_KV_B,
|
MODEL_TENSOR.ATTN_KV_B,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
MODEL_TENSOR.ATTN_K_B,
|
MODEL_TENSOR.ATTN_K_B,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
MODEL_TENSOR.ATTN_V_B,
|
MODEL_TENSOR.ATTN_V_B,
|
||||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||||
|
|
@@ -3207,6 +3238,7 @@ class VisionProjectorType:
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"
+    DEEPSEEKOCR = "deepseekocr"


 # Items here are (block size, type size)
@@ -1457,6 +1459,58 @@ class TensorNameMap:
             "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
         ),

+        MODEL_TENSOR.V_SAM_POS_EMBD: (
+            "model.sam_model.pos_embed",
+        ),
+
+        MODEL_TENSOR.V_SAM_PATCH_EMBD: (
+            "model.sam_model.patch_embed.proj",
+        ),
+
+        MODEL_TENSOR.V_SAM_PRE_NORM: (
+            "model.sam_model.blocks.{bid}.norm1",
+        ),
+
+        MODEL_TENSOR.V_SAM_POST_NORM: (
+            "model.sam_model.blocks.{bid}.norm2",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_POS_H: (
+            "model.sam_model.blocks.{bid}.attn.rel_pos_h",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_POS_W: (
+            "model.sam_model.blocks.{bid}.attn.rel_pos_w",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_QKV: (
+            "model.sam_model.blocks.{bid}.attn.qkv",
+        ),
+
+        MODEL_TENSOR.V_SAM_ATTN_OUT: (
+            "model.sam_model.blocks.{bid}.attn.proj",
+        ),
+
+        MODEL_TENSOR.V_SAM_MLP_LIN_1: (
+            "model.sam_model.blocks.{bid}.mlp.lin1",
+        ),
+
+        MODEL_TENSOR.V_SAM_MLP_LIN_2: (
+            "model.sam_model.blocks.{bid}.mlp.lin2",
+        ),
+
+        MODEL_TENSOR.V_SAM_NECK: (
+            "model.sam_model.neck.{bid}",
+        ),
+
+        MODEL_TENSOR.V_SAM_NET_2: (
+            "model.sam_model.net_2",
+        ),
+
+        MODEL_TENSOR.V_SAM_NET_3: (
+            "model.sam_model.net_3",
+        ),
+
         MODEL_TENSOR.V_MM_POST_FC_NORM: (
             "model.vision.linear_proj.norm1", # cogvlm
         ),
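
Reviewer note: the trailing commas added to the single-entry mappings above are load-bearing. In Python, ("x") is just a parenthesized string while ("x",) is a one-element tuple, and TensorNameMap iterates over each entry's value, so a bare string would be walked character by character:

    # why the trailing commas matter
    not_a_tuple = ("model.sam_model.pos_embed")
    a_tuple = ("model.sam_model.pos_embed",)
    assert isinstance(not_a_tuple, str)
    assert list(not_a_tuple)[0] == "m"  # iterating a str yields characters
    assert list(a_tuple) == ["model.sam_model.pos_embed"]
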
@@ -129,6 +129,24 @@
 #define TN_TOK_BOI          "v.boi"
 #define TN_TOK_EOI          "v.eoi"

+// deepseek-ocr
+#define TN_SAM_POS_EMBD     "v.sam.pos_embd"
+#define TN_SAM_PATCH_EMBD   "v.sam.patch_embd"
+#define TN_SAM_PRE_NORM     "v.sam.blk.%d.pre_ln"
+#define TN_SAM_POST_NORM    "v.sam.blk.%d.post_ln"
+#define TN_SAM_ATTN_POS_H   "v.sam.blk.%d.attn.pos_h"
+#define TN_SAM_ATTN_POS_W   "v.sam.blk.%d.attn.pos_w"
+#define TN_SAM_ATTN_QKV     "v.sam.blk.%d.attn.qkv"
+#define TN_SAM_ATTN_OUT     "v.sam.blk.%d.attn.out"
+#define TN_SAM_MLP_LIN_1    "v.sam.blk.%d.mlp.lin1"
+#define TN_SAM_MLP_LIN_2    "v.sam.blk.%d.mlp.lin2"
+#define TN_SAM_NECK         "v.sam.neck.%d"
+#define TN_SAM_NET_2        "v.sam.net_2"
+#define TN_SAM_NET_3        "v.sam.net_3"
+
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@@ -156,6 +174,7 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_DEEPSEEK_OCR,
     PROJECTOR_TYPE_UNKNOWN,
 };
@@ -182,6 +201,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_DEEPSEEK_OCR, "deepseekocr"},
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -222,6 +222,33 @@ struct clip_hparams {
         warmup_image_size = n_tok_per_side * patch_size * cur_merge;
         // TODO: support warmup size for custom token numbers
     }
+
+    // sam vit deepseek-ocr
+    std::vector<int32_t> global_attn_indices() const {
+        switch (n_embd) {
+            case 768:  return { 2, 5, 8, 11 };
+            case 1024: return { 5, 11, 17, 23 };
+            case 1280: return { 7, 15, 23, 31 };
+            default:
+                fprintf(stderr, "%s: unsupported n_embd = %d\n", __func__, n_embd);
+                break;
+        }
+
+        return {};
+    }
+
+    bool is_global_attn(int32_t layer) const {
+        const auto indices = global_attn_indices();
+
+        for (const auto & idx : indices) {
+            if (layer == idx) {
+                return true;
+            }
+        }
+
+        return false;
+    }
 };

 struct clip_layer {
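
Reviewer note: the indices above encode the SAM ViT schedule in which every (depth/4)-th block uses global attention and all other blocks use 14x14 windowed attention. A sketch of how the schedule plays out for ViT-B (n_embd = 768, 12 blocks):

    # sketch: SAM global/window attention schedule per embedding width
    def global_attn_indices(n_embd: int) -> list[int]:
        return {768: [2, 5, 8, 11], 1024: [5, 11, 17, 23], 1280: [7, 15, 23, 31]}[n_embd]

    layers = ["global" if il in global_attn_indices(768) else "window" for il in range(12)]
    assert layers.count("global") == 4  # one global block per quarter of the depth
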
@@ -271,6 +298,10 @@ struct clip_layer {
     bool has_deepstack() const {
         return deepstack_fc1_w != nullptr;
     }
+
+    // sam rel_pos
+    ggml_tensor * rel_pos_w = nullptr;
+    ggml_tensor * rel_pos_h = nullptr;
 };

 struct clip_model {
@@ -308,6 +339,7 @@ struct clip_model {
     ggml_tensor * mm_2_b = nullptr;

     ggml_tensor * image_newline = nullptr;
+    ggml_tensor * view_seperator = nullptr;

     // Yi type models with mlp+normalization projection
     ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
@@ -400,6 +432,11 @@ struct clip_model {
     ggml_tensor * mm_boi = nullptr;
     ggml_tensor * mm_eoi = nullptr;

+    // deepseek ocr sam
+    ggml_tensor * patch_embed_proj_w = nullptr;
+    ggml_tensor * patch_embed_proj_b = nullptr;
+    ggml_tensor * pos_embed = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -409,6 +446,15 @@ struct clip_model {
         return proj_type == PROJECTOR_TYPE_ULTRAVOX
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
     }
+
+    ggml_tensor * neck_conv_0   = nullptr;
+    ggml_tensor * neck_norm_0_w = nullptr;
+    ggml_tensor * neck_norm_0_b = nullptr;
+    ggml_tensor * neck_conv_1   = nullptr;
+    ggml_tensor * neck_norm_1_w = nullptr;
+    ggml_tensor * neck_norm_1_b = nullptr;
+
+    std::vector<clip_layer> enc_layers;
 };

 struct clip_ctx {
@@ -521,9 +567,9 @@ struct clip_graph {
         hparams(model.hparams),
         img(img),
         patch_size(hparams.patch_size),
-        n_patches_x(img.nx / patch_size),
-        n_patches_y(img.ny / patch_size),
-        n_patches(n_patches_x * n_patches_y),
+        n_patches_x(img.nx / patch_size),     // sam 1024 / 16 = 64
+        n_patches_y(img.ny / patch_size),     // sam 1024 / 16 = 64
+        n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096
         n_embd(hparams.n_embd),
         n_head(hparams.n_head),
         d_head(n_embd / n_head),
@@ -619,6 +665,244 @@ struct clip_graph {
         return gf;
     }

+    ggml_tensor * build_sam_enc(ggml_tensor * inp_raw,
+                                const int enc_image_size = 1024) {
+        constexpr int enc_n_embd     = 768;
+        constexpr int _depth         = 12;
+        constexpr int enc_n_heads    = 12;
+        constexpr int enc_d_heads    = enc_n_embd / enc_n_heads;
+        constexpr int _prompt_n_embd = 256;
+        constexpr int enc_patch_size = 16;
+        constexpr int _window_size   = 14;
+
+        const int enc_n_patches = enc_image_size / enc_patch_size; // 64
+
+        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches * enc_n_patches, enc_n_embd);
+        inpL = ggml_add(ctx0, inpL, model.pos_embed);
+
+        ggml_tensor * cur = inpL;
+
+        // loop over layers
+        for (int il = 0; il < _depth; il++) {
+            auto & layer = model.enc_layers[il];
+
+            ggml_tensor * inp_layer = cur; // residual input of this layer
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "enc_layer_inp_normed", il);
+
+            const int64_t w0 = cur->ne[1];
+            const int64_t h0 = cur->ne[2];
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - apply window partition
+                // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
+                cur = ggml_win_part(ctx0, cur, _window_size);
+            }
+
+            const int64_t W = cur->ne[1];
+            const int64_t H = cur->ne[2];
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = ggml_add(ctx0, cur, layer.qkv_b);
+                const int B = cur->ne[3];
+
+                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B);
+                cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));
+
+                ggml_tensor * Qcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0);
+                Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B);
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+                Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads);
+
+                ggml_tensor * Kcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 1 * cur->nb[3]);
+                Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B);
+                Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads);
+
+                ggml_tensor * Vcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]);
+                Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B);
+                Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed
+                Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * enc_n_heads);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur);
+
+                struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));
+
+                struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W);
+                struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H);
+
+                struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
+
+                struct ggml_tensor * rel_w = ggml_cont(
+                    ctx0,
+                    ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0,
+                                 2, 1, 3));
+                struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
+
+                struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
+
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max);
+
+                cur = ggml_reshape_4d(
+                    ctx0,
+                    ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B),
+                                                 0, 2, 1, 3)),
+                    enc_n_embd, W, H, B);
+
+                cur = ggml_mul_mat(ctx0, layer.o_w, cur);
+                cur = ggml_add_inplace(ctx0, cur, layer.o_b);
+            }
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - reverse window partition
+                cur = ggml_win_unpart(ctx0, cur, w0, h0, _window_size);
+            }
+
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
+            // re-add the layer input, i.e. residual
+            cur = ggml_add(ctx0, cur, inp_layer);
+
+            cb(cur, "ffn_inp", il);
+
+            ggml_tensor * ffn_inp = cur;
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b, layer.ff_down_w,
+                            layer.ff_down_b, hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
+            // residual 2
+            cur = ggml_add(ctx0, ffn_inp, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur; // layer output feeds the next block
+        }
+
+        // neck: [B, 64, 64, 768] -> [B, 256, 64, 64]
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
+
+        cur = ggml_conv_2d_sk_p0(ctx0, model.neck_conv_0, cur);
+
+        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_0_w, model.neck_norm_0_b, hparams.eps);
+
+        cur = ggml_conv_2d_s1_ph(ctx0, model.neck_conv_1, cur);
+
+        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_1_w, model.neck_norm_1_b, hparams.eps);
+
+        //cur = ggml_cpy(ctx0, cur, state.embd_img);
+
+        ggml_build_forward_expand(gf, cur);
+        return cur;
+    }
+
+    ggml_tensor * sam_layer_norm_2d(ggml_context * ctx0,
+                                    ggml_tensor * layer,
+                                    int n_channels,
+                                    ggml_tensor * w,
+                                    ggml_tensor * b,
+                                    float eps) {
+        // LayerNorm2d
+        // normalize along channel dimension
+        // TODO: better implementation
+        layer = ggml_permute(ctx0, ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), 2, 0,
+                             1, 3);
+
+        layer =
+            ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), layer),
+                     ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));
+
+        return layer;
+    }
+
+    ggml_cgraph * build_deepseek_ocr() {
+        // patch embedding
+        ggml_tensor * inp_raw = build_inp_raw();
+
+        ggml_tensor * global_features_1 = build_sam_enc(inp_raw);
+        ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1);
+
+        // torch: global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
+        ggml_tensor * global_features = ggml_concat(ctx0, global_features_1, global_features_2, 0);
+        global_features = build_global_local_features(
+            ctx0,
+            global_features,
+            n_patches_y,
+            n_patches_x,
+            n_embd);
+
+        ggml_build_forward_expand(gf, global_features);
+        return gf;
+    }
+
+    // global_features: [n_dim, h*w]
+    // image_newline: [n_dim]
+    // view_separator: [n_dim]
+    ggml_tensor * build_global_local_features(ggml_context * ctx0,
+                                              ggml_tensor * global_features,
+                                              int h,
+                                              int w,
+                                              int n_dim) {
+        GGML_ASSERT(model.image_newline != nullptr);
+        GGML_ASSERT(model.view_seperator != nullptr);
+        GGML_ASSERT(global_features->ne[0] == (int64_t) n_dim);
+        GGML_ASSERT(global_features->ne[1] == (int64_t) (h * w));
+
+        // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
+        ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h)
+        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
+
+        // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
+        ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim)
+
+        ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim)
+        nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim)
+        nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim)
+
+        // 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim)
+        t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
+
+        // 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1))
+        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h)
+        t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1))
+
+        // 5) append view_separator as an extra "token":
+        //    view_separator: [n_dim] -> [n_dim, 1]
+        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
+
+        // concat along token dimension (dim=1):
+        ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
+
+        return global_local_features;
+    }
+
     ggml_cgraph * build_pixtral() {
         const int n_merge = hparams.n_merge;
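
Reviewer note: build_sam_enc relies on ggml_win_part / ggml_win_unpart to hop between the full 64x64 patch grid and 14x14 local windows on the non-global layers. A small sketch of the shape bookkeeping (this is my reading of the window-partition semantics: pad the grid up to a multiple of the window size, then split):

    # sketch: window-partition shape bookkeeping (assumed semantics)
    def win_part_shape(w: int, h: int, window: int) -> tuple[int, int]:
        px = (window - w % window) % window  # right padding
        py = (window - h % window) % window  # bottom padding
        npx = (w + px) // window             # windows along x
        npy = (h + py) // window             # windows along y
        return npx * npy, window             # (number of windows, window side)

    # SAM: 64x64 patch grid, window 14 -> padded to 70x70 -> 25 windows of 14x14
    assert win_part_shape(64, 64, 14) == (25, 14)
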
@@ -1215,7 +1499,7 @@ struct clip_graph {
             norm_t,
             hparams.ffn_op,
             model.position_embeddings,
-            nullptr);
+            nullptr); // shape [1024, 16, 16]

         // remove CLS token
         cur = ggml_view_2d(ctx0, cur,
@@ -1261,6 +1545,65 @@ struct clip_graph {
         return gf;
     }

+    ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+        ggml_tensor * inp =
+            ggml_cont_3d(ctx0, ggml_dup_tensor(ctx0, patch_embeds), patch_embeds->ne[0], n_patches_x, n_patches_y);
+        //ggml_tensor * inp = ggml_cpy(ctx0, inpL, ggml_dup_tensor(ctx0, inpL));
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // The larger models use a different ViT, which uses RMS norm instead of layer norm
+        // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188
+        norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) ?
+            NORM_TYPE_RMS     // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B)
+            :
+            NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models)
+
+        ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, model.position_embeddings,
+                                      nullptr); // shape [1024, 16, 16]
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur, n_embd, n_patches, ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        {
+            const int scale_factor = model.hparams.n_merge;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor,
+                               width / scale_factor, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]);
+        }
+
+        // projector (always using GELU activation)
+        {
+            // projector LayerNorm uses pytorch's default eps = 1e-5
+            // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
+            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_3_b);
+        }
+
+        return cur;
+    }
+
     ggml_cgraph * build_llama4() {
         GGML_ASSERT(model.class_embedding != nullptr);
         GGML_ASSERT(model.position_embeddings != nullptr);
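
Reviewer note: the pixel-shuffle block in build_dp_ocr_clip trades spatial resolution for channel depth before the projector, reducing the token count by scale_factor squared. A sketch of the shape arithmetic (the example numbers are illustrative, not taken from the checkpoint):

    # sketch: pixel-shuffle shape arithmetic
    def pixel_shuffle_shape(n_embd: int, h: int, w: int, s: int) -> tuple[int, int]:
        assert h % s == 0 and w % s == 0
        return n_embd * s * s, (h // s) * (w // s)  # (channels, tokens)

    # e.g. a 16x16 grid of 1024-dim patches with n_merge = 2 -> 64 tokens of 4096 dims
    assert pixel_shuffle_shape(1024, 16, 16, 2) == (4096, 64)
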
@@ -2164,18 +2507,41 @@ private:
         return inpL;
     }

+    // build the SAM encoder input after conv2d (inp_raw --> patches)
+    // returns tensor with shape [enc_n_embd, enc_n_patches]
+    ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,
+                                const int enc_patch_size,
+                                const int enc_n_patches,
+                                const int enc_n_embd) {
+        GGML_ASSERT(model.patch_embed_proj_w != nullptr);
+        GGML_ASSERT(model.patch_embed_proj_b != nullptr);
+        // Image to Patch Embedding.
+        // inp_raw: sam shape = [1024, 1024, 3]
+        // patch_embed_proj_w shape = [768, 3, 16, 16]
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0,
+                                         1, 1); // [64, 64, 768]
+        inp = ggml_reshape_2d(ctx0, inp, enc_n_patches, enc_n_embd); // [4096, 768]
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [768, 4096]
+        inp = ggml_add(ctx0, inp, model.patch_embed_proj_b);
+        cb(inp, "enc_patch_bias", -1);
+        return inp;
+    }
+
     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]
     ggml_tensor * build_inp() {
-        ggml_tensor * inp_raw = build_inp_raw();
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
-        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        // Image to Patch Embedding.
+        ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3]
+        // sam patch_embeddings_0 shape = [768, 3, 16, 16]
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768]
+        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768]
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // sam shape = [768, 4096]
         if (model.patch_bias) {
+            // sam patch_bias shape = [768]
             inp = ggml_add(ctx0, inp, model.patch_bias);
             cb(inp, "patch_bias", -1);
         }
-        return inp;
+        return inp; // shape = [n_embd, n_patches] same as [768, 4096]
     }

     ggml_tensor * build_inp_raw(int channels = 3) {
@@ -3236,6 +3602,10 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
                 } break;
+            case PROJECTOR_TYPE_DEEPSEEK_OCR:
+                {
+                    // TODO: tensor loading not implemented yet
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -4192,6 +4562,59 @@ private:
     }
 };

+static std::vector<std::pair<int, int>> ds_build_target_ratios(const int min_num, const int max_num) {
+    std::vector<std::pair<int, int>> ratios;
+    for (int n = min_num; n <= max_num; ++n) {
+        for (int i = 1; i <= n; ++i) {
+            for (int j = 1; j <= n; ++j) {
+                if (const int blocks = i * j; blocks >= min_num && blocks <= max_num) {
+                    ratios.emplace_back(i, j); // (cols, rows)
+                }
+            }
+        }
+    }
+
+    // sort by total blocks like in Python (key=lambda x: x[0] * x[1]);
+    // tie-break on the pair itself so duplicates become adjacent for std::unique below
+    std::sort(ratios.begin(), ratios.end(),
+              [](const auto & a, const auto & b) {
+                  const int pa = a.first * a.second;
+                  const int pb = b.first * b.second;
+                  return pa != pb ? pa < pb : a < b;
+              });
+
+    // dedup (the same pair is generated for several values of n)
+    ratios.erase(std::unique(ratios.begin(), ratios.end()), ratios.end());
+    return ratios;
+}
+
+static std::pair<int, int> ds_find_closest_aspect_ratio(
+        const float aspect_ratio,
+        const std::vector<std::pair<int, int>> & target_ratios,
+        const int width,
+        const int height,
+        const int image_size) {
+    float best_diff = std::numeric_limits<float>::infinity();
+    std::pair<int, int> best_ratio = {1, 1};
+    const float area = static_cast<float>(width) * static_cast<float>(height);
+
+    for (const auto & r : target_ratios) {
+        const float target_ar = static_cast<float>(r.first) / static_cast<float>(r.second);
+
+        if (const float diff = std::fabs(aspect_ratio - target_ar); diff < best_diff) {
+            best_diff = diff;
+            best_ratio = r;
+        } else if (diff == best_diff) {
+            // same as python: prefer this ratio if the image area is "large enough"
+            if (const float needed_area = 0.5f * image_size * image_size * r.first * r.second; area > needed_area) {
+                best_ratio = r;
+            }
+        }
+    }
+
+    return best_ratio; // (cols, rows)
+}
+
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
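
Reviewer note: for cross-checking the two helpers above, here is a compact Python equivalent following the InternVL-style dynamic-tiling reference this code mirrors (the set comprehension makes the dedup concern moot):

    # sketch: Python equivalent of ds_build_target_ratios / ds_find_closest_aspect_ratio
    def build_target_ratios(min_num: int, max_num: int) -> list[tuple[int, int]]:
        ratios = {(i, j)
                  for n in range(min_num, max_num + 1)
                  for i in range(1, n + 1)
                  for j in range(1, n + 1)
                  if min_num <= i * j <= max_num}
        return sorted(ratios, key=lambda r: r[0] * r[1])

    def closest_ratio(w: int, h: int, ratios: list[tuple[int, int]], image_size: int) -> tuple[int, int]:
        ar = w / h
        best, best_diff = (1, 1), float("inf")
        for cols, rows in ratios:
            diff = abs(ar - cols / rows)
            if diff < best_diff:
                best, best_diff = (cols, rows), diff
            elif diff == best_diff and w * h > 0.5 * image_size * image_size * cols * rows:
                best = (cols, rows)
        return best

    assert closest_ratio(1280, 640, build_target_ratios(2, 9), 640) == (2, 1)
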
@@ -4406,6 +4829,69 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                     }
                 }
             } break;
+        case PROJECTOR_TYPE_DEEPSEEK_OCR:
+            {
+                // configurable, or read from params
+                const int min_num = 2;
+                const int max_num = 9;
+                const int image_size = params.image_size; // typically 640
+                const bool use_thumbnail = true; // mimic python's use_thumbnail
+
+                // original image size
+                const int orig_w = original_size.width;
+                const int orig_h = original_size.height;
+
+                // 1) build candidate grids (cols, rows)
+                auto target_ratios = ds_build_target_ratios(min_num, max_num);
+
+                // 2) pick the grid that best matches the original aspect ratio
+                const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
+                auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
+                const int grid_cols = best.first;  // how many tiles horizontally
+                const int grid_rows = best.second; // how many tiles vertically
+
+                // 3) compute the target (forced) size; python did:
+                //    target_width  = image_size * cols
+                //    target_height = image_size * rows
+                const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };
+
+                // 4) prepare slice instructions, same style as the idefics3 branch
+                llava_uhd::slice_instructions instructions;
+                instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
+                instructions.refined_size  = refined_size;
+                instructions.grid_size     = clip_image_size{ grid_cols, grid_rows };
+
+                // in deepseek python they always produce *full* 640x640 blocks,
+                // so we can do a simple double loop over rows/cols:
+                for (int r = 0; r < grid_rows; ++r) {
+                    for (int c = 0; c < grid_cols; ++c) {
+                        const int x = c * image_size;
+                        const int y = r * image_size;
+
+                        instructions.slices.push_back(llava_uhd::slice_coordinates{
+                            /* x    */ x,
+                            /* y    */ y,
+                            /* size */ clip_image_size{ image_size, image_size }
+                        });
+                    }
+                }
+
+                // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
+                auto imgs = llava_uhd::slice_image(img, instructions);
+
+                // 6) cast & normalize like the idefics3 branch
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
+
+                // keep the grid info; the model may need to know how to reassemble / attend
+                res_imgs->grid_x = grid_cols;
+                res_imgs->grid_y = grid_rows;
+            } break;
+
         default:
             LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type());
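
Reviewer note: putting the preprocessing branch together, a sketch of what it produces for a hypothetical 1280x640 input with image_size = 640 (grid from the closest-aspect-ratio search above):

    # sketch: tiling outcome for a hypothetical 1280x640 input
    image_size = 640
    grid_cols, grid_rows = 2, 1  # closest ratio to 2.0 within the 2..9 block budget
    refined = (image_size * grid_cols, image_size * grid_rows)  # resize target
    slices = [(c * image_size, r * image_size, image_size, image_size)
              for r in range(grid_rows) for c in range(grid_cols)]
    assert refined == (1280, 640)
    assert len(slices) == grid_cols * grid_rows  # one full 640x640 crop per tile
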