model: support GLM4V vision encoder (#18042)
* convert ok
* no deepstack
* less new tensors
* cgraph ok
* add mrope for text model
* faster patch merger
* add GGML_ROPE_TYPE_MRNORM
* add support for metal
* move glm4v to dedicated graph
* convert: add norm_embd
* clip: add debugging fn
* working correctly
* fix style
* use bicubic
* fix mrope metal
* improve cpu
* convert to neox ordering on conversion
* revert backend changes
* force stop if using old weight
* support moe variant
* fix conversion
* fix convert (2)
* Update tools/mtmd/clip-graph.h
* process mrope_section on TextModel base class
* resolve merge conflict

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Commit 3d86c6c2b5 (parent 9963b81f63)
@@ -862,6 +862,14 @@ class TextModel(ModelBase):
                 logger.warning(f"Unknown RoPE type: {rope_type}")
             logger.info(f"gguf: rope scaling type = {rope_gguf_type.name}")
 
+        if "mrope_section" in self.rope_parameters:
+            mrope_section = self.rope_parameters["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+            logger.info(f"gguf: mrope sections: {mrope_section[:4]}")
+
         if (rope_theta := rope_params.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
             logger.info(f"gguf: rope theta = {rope_theta}")

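Note: the hunk above moves M-RoPE section handling into the TextModel base class, so any converter whose rope_parameters carry an mrope_section gets the padded four-section layout for free. A minimal standalone sketch of the padding behavior (plain Python; the helper name is hypothetical):

def pad_mrope_section(mrope_section: list[int]) -> list[int]:
    # pad [time, height, width] to [time, height, width, extra]
    mrope_section = list(mrope_section)
    while len(mrope_section) < 4:
        mrope_section.append(0)
    return mrope_section[:4]

assert pad_mrope_section([16, 24, 24]) == [16, 24, 24, 0]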
@@ -3739,9 +3747,6 @@ class Qwen2VLModel(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        mrope_section = self.hparams["rope_scaling"]["mrope_section"]
-        mrope_section += [0] * max(0, 4 - len(mrope_section))
-        self.gguf_writer.add_rope_dimension_sections(mrope_section)
 
     def set_vocab(self):
         try:

@@ -4377,6 +4382,30 @@ class Qwen3VLVisionModel(MmprojModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
+class Glm4VVisionModel(Qwen3VLVisionModel):
+    def set_gguf_parameters(self):
+        MmprojModel.set_gguf_parameters(self)  # skip Qwen3VLVisionModel parameters
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
+
+        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+        if hidden_act == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+
+        rms_norm_eps = self.hparams_vision.get("rms_norm_eps", 1e-5)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.")
+        if name.startswith("visual.merger."):
+            return [(self.map_tensor_name(name), data_torch)]
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Qwen3VLForConditionalGeneration")
 class Qwen3VLTextModel(Qwen3Model):
     model_arch = gguf.MODEL_ARCH.QWEN3VL

@@ -4385,20 +4414,6 @@ class Qwen3VLTextModel(Qwen3Model):
         super().set_gguf_parameters()
 
         # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
-        text_config = self.hparams.get("text_config", {})
-        # rope_scaling is deprecated in V5, use rope_parameters instead
-        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
-        if rope_scaling.get("mrope_section"):
-            # mrope_section contains [time, height, width] dimensions
-            mrope_section = rope_scaling["mrope_section"]
-            # Pad to 4 dimensions [time, height, width, extra]
-            while len(mrope_section) < 4:
-                mrope_section.append(0)
-            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
-            logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

@@ -4417,22 +4432,6 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-
-        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
-        text_config = self.hparams.get("text_config", {})
-        # rope_scaling is deprecated in V5, use rope_parameters instead
-        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
-        if rope_scaling.get("mrope_section"):
-            # mrope_section contains [time, height, width] dimensions
-            mrope_section = rope_scaling["mrope_section"]
-            # Pad to 4 dimensions [time, height, width, extra]
-            while len(mrope_section) < 4:
-                mrope_section.append(0)
-            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
-            logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

@@ -7795,6 +7794,15 @@ class JaisModel(TextModel):
 @ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration")
 class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
+    use_mrope = False
+    partial_rotary_factor = 0.5
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 0.5)
+        if "mrope_section" in self.rope_parameters:
+            self.use_mrope = True
+            logger.info("Q/K weight will need to be permuted for M-RoPE")
 
     def set_vocab(self):
         from transformers import AutoTokenizer

@@ -7816,17 +7824,49 @@ class Glm4Model(TextModel):
         super().set_gguf_parameters()
         if (rope_dim := self.hparams.get("head_dim")) is None:
             rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
-        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.partial_rotary_factor))
+
+    @staticmethod
+    def normal_to_neox(weights: Tensor, n_head: int, n_head_kv: int, head_dim: int, partial_rotary_factor: float) -> Tensor:
+        orig_shape = weights.shape
+        if len(orig_shape) == 1:
+            weights = weights.unsqueeze(1)  # [out_dim, 1]
+        if len(weights.shape) != 2:
+            raise ValueError("Only 1D and 2D tensors are supported.")
+        n_effective_heads = weights.shape[0] // head_dim
+        if n_head_kv is not None and n_effective_heads != n_head:
+            if n_effective_heads != n_head_kv:
+                raise AssertionError(f"Mismatch in effective heads: computed {n_effective_heads}, expected {n_head} or {n_head_kv}")
+        rotary_dim = int(head_dim * partial_rotary_factor)
+        if rotary_dim % 2 != 0:
+            raise ValueError("rotary_dim must be even.")
+        reshaped = weights.reshape(n_effective_heads, head_dim, -1)
+        rot_part = reshaped[:, :rotary_dim, :]
+        non_rot_part = reshaped[:, rotary_dim:, :]
+        permuted_rot = torch.cat((rot_part[:, ::2, :], rot_part[:, 1::2, :]), dim=1)
+        combined = torch.cat((permuted_rot, non_rot_part), dim=1)
+        result = combined.reshape(weights.shape)
+        return result if len(orig_shape) != 1 else result.squeeze(1)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if name.startswith("model.visual."):  # ignore visual part of Glm4v
             return []
         elif name.startswith("model.language_model."):
             name = name.replace("language_model.", "")  # for Glm4v
+        if self.use_mrope:
+            n_head = self.hparams["num_attention_heads"]
+            n_kv_head = self.hparams["num_key_value_heads"]
+            n_embd = self.hparams["hidden_size"]
+            head_dim = n_embd // n_head
+            # because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_kv_head, head_dim, self.partial_rotary_factor)
         return super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register("Glm4MoeForCausalLM")
+@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
 class Glm4MoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4_MOE

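To see what normal_to_neox does, take a single head with head_dim = 4 and partial_rotary_factor = 1.0: rows ordered as interleaved rotary pairs (x0, x1, x2, x3) come out as (x0, x2, x1, x3), i.e. even- and odd-indexed rows are regrouped into the two contiguous halves that a Neox-style RoPE kernel expects. A self-contained sketch (assumes only torch; the helper name is hypothetical):

import torch

def normal_to_neox_1head(w: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    # w: [head_dim, n_cols]; split the rotary rows into even/odd halves
    rot, tail = w[:rotary_dim], w[rotary_dim:]
    return torch.cat((rot[0::2], rot[1::2], tail), dim=0)

w = torch.arange(4).reshape(4, 1)  # rows x0..x3
print(normal_to_neox_1head(w, 4).flatten().tolist())  # [0, 2, 1, 3]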
@@ -7893,6 +7933,7 @@ class Glm4MoeModel(TextModel):
 
     _experts: list[dict[str, Tensor]] | None = None
 
+    # note: unlike GLM4V non-MoE, we don't need to permute Q/K here since GLM4V_MOE uses Neox ordering already
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:

@@ -643,6 +643,7 @@ class MODEL_TENSOR(IntEnum):
     V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
+    V_ENC_EMBD_NORM = auto()
     V_ENC_EMBD_POS = auto()
     V_ENC_INPUT_NORM = auto()
     V_ENC_ATTN_QKV = auto()

@@ -661,6 +662,7 @@ class MODEL_TENSOR(IntEnum):
     V_LAYER_SCALE_2 = auto()
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
+    V_MM_POST_NORM = auto()
     V_MM_INP_NORM = auto()
     V_MM_INP_PROJ = auto()  # gemma3
     V_MM_SOFT_EMB_NORM = auto()  # gemma3

@@ -1016,6 +1018,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
+    MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
     MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
     MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",

@@ -1034,6 +1037,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
     MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM: "v.post_ln",
+    MODEL_TENSOR.V_MM_POST_NORM: "mm.post_norm",
     MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
     MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",

@@ -1094,6 +1098,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_MMPROJ_PEG,
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
+        MODEL_TENSOR.V_ENC_EMBD_NORM,
         MODEL_TENSOR.V_ENC_EMBD_POS,
         MODEL_TENSOR.V_ENC_INPUT_NORM,
         MODEL_TENSOR.V_ENC_ATTN_QKV,

@@ -1112,6 +1117,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_LAYER_SCALE_2,
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
+        MODEL_TENSOR.V_MM_POST_NORM,
         MODEL_TENSOR.V_MM_INP_PROJ,
         MODEL_TENSOR.V_MM_INP_NORM,
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM,

@@ -3357,6 +3363,7 @@ class VisionProjectorType:
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"
+    GLM4V = "glm4v"
 
 
 # Items here are (block size, type size)

@@ -1212,6 +1212,7 @@ class TensorNameMap:
     MODEL_TENSOR.V_MMPROJ_FC: (
         "model.connector.modality_projection.proj",  # SmolVLM
         "model.vision.linear_proj.linear_proj",  # cogvlm
+        "visual.merger.proj",  # glm4v
     ),
 
     MODEL_TENSOR.V_MMPROJ_MLP: (

@@ -1245,6 +1246,10 @@ class TensorNameMap:
         "model.vision.patch_embedding.proj",  # cogvlm
     ),
 
+    MODEL_TENSOR.V_ENC_EMBD_NORM: (
+        "visual.post_conv_layernorm",  # glm4v
+    ),
+
     MODEL_TENSOR.V_ENC_EMBD_POS: (
         "vision_tower.vision_model.embeddings.position_embedding",
         "model.vision_tower.embeddings.position_embeddings",  # Intern-S1

@@ -1254,6 +1259,7 @@ class TensorNameMap:
         "vision_tower.patch_embed.pos_emb",  # kimi-vl
         "visual.pos_embed",  # qwen3vl
         "model.vision.patch_embedding.position_embedding",  # cogvlm
+        "visual.embeddings.position_embedding",  # glm4v
     ),
 
     MODEL_TENSOR.V_ENC_ATTN_QKV: (

@@ -1409,6 +1415,11 @@ class TensorNameMap:
         "vision_model.layernorm_post",  # llama4
         "visual.merger.ln_q",  # qwen2vl
         "vision_tower.encoder.final_layernorm",  # kimi-vl
+        "visual.post_layernorm",  # glm4v
+    ),
+
+    MODEL_TENSOR.V_MM_POST_NORM: (
+        "visual.merger.post_projection_norm",  # glm4v
     ),
 
     MODEL_TENSOR.V_MM_INP_PROJ: (

@@ -1478,6 +1489,7 @@ class TensorNameMap:
     MODEL_TENSOR.V_MM_PATCH_MERGER: (
         "multi_modal_projector.patch_merger.merging_layer",  # mistral small 3.1 - hf
         "patch_merger.merging_layer",  # mistral
+        "visual.downsample",  # glm4v
     ),
 
     MODEL_TENSOR.V_DS_NORM: (

@@ -1498,14 +1510,17 @@ class TensorNameMap:
 
     MODEL_TENSOR.V_MM_UP: (
         "model.vision.linear_proj.dense_h_to_4h",  # cogvlm
+        "visual.merger.up_proj",  # glm4v
     ),
 
     MODEL_TENSOR.V_MM_DOWN: (
         "model.vision.linear_proj.dense_4h_to_h",  # cogvlm
+        "visual.merger.down_proj",  # glm4v
     ),
 
     MODEL_TENSOR.V_MM_GATE: (
         "model.vision.linear_proj.gate_proj",  # cogvlm
+        "visual.merger.gate_proj",  # glm4v
     ),
 
     MODEL_TENSOR.V_TOK_BOI: (

@@ -231,3 +231,7 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
 
     return false;
 }
+
+bool llama_hparams::use_mrope() const {
+    return rope_sections[0] > 0 && rope_sections[1] > 0;
+}

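use_mrope() treats a model as M-RoPE-capable when at least the time and height sections are non-zero; a GGUF converted before this change carries no rope-sections metadata, so the array stays all-zero and the model silently falls back to plain RoPE. A rough Python rendering of the same check (section values hypothetical):

rope_sections = [16, 24, 24, 0]  # [time, height, width, extra], as written by the converter
use_mrope = rope_sections[0] > 0 and rope_sections[1] > 0
print(use_mrope)  # True; an old text-only GGUF would have [0, 0, 0, 0] -> False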
@@ -270,6 +270,8 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+    bool use_mrope() const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

@@ -1689,7 +1689,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GLM4:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
                 switch (hparams.n_layer) {
                     case 40: type = LLM_TYPE_9B; break;
                     case 61: type = LLM_TYPE_32B; break;

@@ -1698,8 +1699,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
 
                 // MoE parameters
                 ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);

@@ -7792,7 +7794,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_PLM:
         case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GLM4:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_GRANITE_HYBRID:

@@ -7854,7 +7855,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_SMALLTHINKER:
-        case LLM_ARCH_GLM4_MOE:
         case LLM_ARCH_SEED_OSS:
         case LLM_ARCH_GROVEMOE:
         case LLM_ARCH_APERTUS:

@@ -7871,6 +7871,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN3VLMOE:
             return LLAMA_ROPE_TYPE_IMROPE;
 
+        case LLM_ARCH_GLM4:
+            return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NORM;
+        case LLM_ARCH_GLM4_MOE:
+            return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
+
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
             GGML_ABORT("unknown architecture");

@@ -5,11 +5,20 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
     inpL = build_inp_embd(model.tok_embd);
 
+    bool use_mrope = hparams.use_mrope();
+    if (ubatch.embd && !use_mrope) {
+        // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
+        GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
+    }
+
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();

@@ -60,17 +69,25 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
                 Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
                 cb(Kcur, "Kcur_normed", il);
             }
-            Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
 
-            Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
+            if (use_mrope) {
+                Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                // Normal RoPE
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
+                        rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
+                        rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+            }
 
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);

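ggml_rope_multi differs from ggml_rope_ext in that each head's rotary pairs are driven by different position streams: the first sections[0] pairs use the temporal position, the next sections[1] the height, then the width, with the fourth slot unused here. An illustrative Python/NumPy sketch of which stream drives which rotary pair (not the ggml kernel; section and position values are hypothetical):

import numpy as np

sections = [16, 24, 24, 0]                   # [t, h, w, extra] rotary pair counts
bounds = np.cumsum(sections)

def stream_for_pair(i):
    # which position stream drives rotary pair i
    return ("t", "h", "w", "extra")[int(np.searchsorted(bounds, i, side="right"))]

print([stream_for_pair(i) for i in (0, 15, 16, 40, 63)])  # ['t', 't', 'h', 'w', 'w']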
@@ -8,11 +8,20 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
     inpL = build_inp_embd(model.tok_embd);
 
+    bool use_mrope = hparams.use_mrope();
+    if (ubatch.embd && !use_mrope) {
+        // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
+        GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
+    }
+
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();

@@ -63,11 +72,25 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
             Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
                     cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
         }
-        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
 
-        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
+        if (use_mrope) {
+            Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+        } else {
+            // Normal RoPE
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
+                    rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
+                    rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+        }
 
         cb(Qcur, "Qcur", il);
         cb(Kcur, "Kcur", il);

@@ -15,6 +15,7 @@ add_library(mtmd
     clip-graph.h
     models/models.h
     models/cogvlm.cpp
+    models/glm4v.cpp
     models/internvl.cpp
     models/kimivl.cpp
     models/llama4.cpp

@@ -9,6 +9,8 @@
 #include <vector>
 #include <functional>
 
+#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)
+
 struct clip_graph {
     const clip_model & model;
     const clip_hparams & hparams;

@@ -49,7 +51,7 @@ struct clip_graph {
     void cb(ggml_tensor * cur0, const char * name, int il) const;
 
     // siglip2 naflex
-    ggml_tensor * resize_position_embeddings();
+    ggml_tensor * resize_position_embeddings(uint32_t interpolation_mode = DEFAULT_INTERPOLATION_MODE);
 
     // build vision transformer (ViT) cgraph
     // this function should cover most of the models

@@ -68,6 +68,7 @@
 #define TN_PATCH_EMBD      "v.patch_embd.weight"  // not rename tensor with ".0" postfix for backwrad compat
 #define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
 #define TN_PATCH_BIAS      "v.patch_embd.bias"
+#define TN_NORM_EMBD       "v.norm_embd.%s"
 #define TN_ATTN_QKV        "%s.blk.%d.attn_qkv.%s"
 #define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"

@@ -86,6 +87,10 @@
 #define TN_LN_PRE          "%s.pre_ln.%s"
 #define TN_LN_POST         "%s.post_ln.%s"
 #define TN_LLAVA_PROJ      "mm.%d.%s"
+#define TN_MM_UP           "mm.up.%s"
+#define TN_MM_GATE         "mm.gate.%s"
+#define TN_MM_DOWN         "mm.down.%s"
+#define TN_MM_POST_NORM    "mm.post_norm.%s"
 #define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"

@@ -95,7 +100,7 @@
 #define TN_MM_INP_PROJ     "mm.input_projection.weight"  // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"     // gemma3
 #define TN_MM_PROJECTOR    "mm.model.fc.weight"          // idefics3
-#define TN_MM_PATCH_MERGER "mm.patch_merger.weight"      // mistral small 3.1
+#define TN_MM_PATCH_MERGER "mm.patch_merger.%s"          // mistral small 3.1, glm4v
 #define TN_TOK_IMG_BREAK   "v.token_embd.img_break"      // pixtral
 #define TN_TOK_GLM_BOI     "adapter.boi"                 // glm-edge (these embeddings are not in text model)
 #define TN_TOK_GLM_EOI     "adapter.eoi"                 // glm-edge (these embeddings are not in text model)

@@ -165,6 +170,7 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_GLM4V,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -192,6 +198,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_GLM4V,     "glm4v"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {

@@ -495,6 +502,8 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
     }
 }
 
+void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value);
+
 //
 // API used internally with mtmd
 //

@@ -158,6 +158,8 @@ struct clip_model {
     ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
     ggml_tensor * patch_bias = nullptr;
     ggml_tensor * position_embeddings = nullptr;
+    ggml_tensor * norm_embd_w = nullptr;
+    ggml_tensor * norm_embd_b = nullptr;
 
     ggml_tensor * pre_ln_w = nullptr;
     ggml_tensor * pre_ln_b = nullptr;

@@ -172,6 +174,14 @@ struct clip_model {
     ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
     ggml_tensor * mm_fc_w;
     ggml_tensor * mm_fc_b;
+    ggml_tensor * mm_ffn_up_w = nullptr;
+    ggml_tensor * mm_ffn_up_b = nullptr;
+    ggml_tensor * mm_ffn_gate_w = nullptr;
+    ggml_tensor * mm_ffn_gate_b = nullptr;
+    ggml_tensor * mm_ffn_down_w = nullptr;
+    ggml_tensor * mm_ffn_down_b = nullptr;
+    ggml_tensor * mm_post_norm_w = nullptr;
+    ggml_tensor * mm_post_norm_b = nullptr;
 
     // LLaVA projection
     ggml_tensor * mm_input_norm_w = nullptr;

@@ -253,9 +263,10 @@ struct clip_model {
     ggml_tensor * mm_input_proj_w = nullptr;
     ggml_tensor * mm_soft_emb_norm_w = nullptr;
 
-    // pixtral
+    // pixtral, glm4v
     ggml_tensor * token_embd_img_break = nullptr;
     ggml_tensor * mm_patch_merger_w = nullptr;
+    ggml_tensor * mm_patch_merger_b = nullptr;
 
     // ultravox / whisper encoder
     ggml_tensor * conv1d_1_w = nullptr;

@@ -264,11 +264,11 @@ void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const {
 }
 
 // siglip2 naflex
-ggml_tensor * clip_graph::resize_position_embeddings() {
+ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) {
     ggml_tensor * pos_embd = model.position_embeddings;
     const int height = img.ny / patch_size;
     const int width = img.nx / patch_size;
-    const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS;
+    const uint32_t mode = interpolation_mode;
     const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
 
     GGML_ASSERT(pos_embd);

@@ -485,19 +485,14 @@ ggml_tensor * clip_graph::build_norm(
         ? ggml_rms_norm(ctx0, cur, norm_eps)
         : ggml_norm(ctx0, cur, norm_eps);
 
-    if (mw || mb) {
-        cb(cur, "norm", il);
-    }
-
     if (mw) {
         cur = ggml_mul(ctx0, cur, mw);
-        if (mb) {
-            cb(cur, "norm_w", il);
-        }
+        cb(cur, "norm_w", il);
     }
 
     if (mb) {
         cur = ggml_add(ctx0, cur, mb);
+        cb(cur, "norm_b", il);
     }
 
     return cur;

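The build_norm refactor above simply attaches a debug callback to each affine stage separately. Functionally the helper is a (RMS-)norm followed by an optional scale and an optional shift; a Python/NumPy sketch of the same flow (assuming RMS semantics; names hypothetical):

import numpy as np

def build_norm(x, w=None, b=None, eps=1e-5):
    y = x / np.sqrt(np.mean(x * x) + eps)  # rms_norm
    if w is not None:
        y = y * w  # debug point "norm_w"
    if b is not None:
        y = y + b  # debug point "norm_b"
    return y

print(build_norm(np.array([1.0, 2.0, 3.0]), w=np.full(3, 2.0)))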
@@ -842,6 +837,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_llava>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_GLM4V:
+            {
+                builder = std::make_unique<clip_graph_glm4v>(ctx, img);
+            } break;
         default:
             GGML_ABORT("missing cgraph builder");
     }

@@ -1155,6 +1154,14 @@ struct clip_model_loader {
                         LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
                     }
                 } break;
+            case PROJECTOR_TYPE_GLM4V:
+                {
+                    hparams.rope_theta = 10000.0f;
+                    hparams.n_merge = 2; // default value for GLM4-V
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                    hparams.set_limit_image_tokens(8, 4096);
+                    hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup
+                } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
                     hparams.rope_theta = 10000.0f;

@@ -1282,6 +1289,9 @@ struct clip_model_loader {
         model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
         model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
 
+        model.norm_embd_w = get_tensor(string_format(TN_NORM_EMBD, "weight"), false);
+        model.norm_embd_b = get_tensor(string_format(TN_NORM_EMBD, "bias"), false);
+
         model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
 
         // layers

@@ -1470,6 +1480,20 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
+            case PROJECTOR_TYPE_GLM4V:
+                {
+                    model.projection = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight"));
+                    model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias"), false);
+                    model.mm_ffn_gate_w = get_tensor(string_format(TN_MM_GATE, "weight"));
+                    model.mm_ffn_gate_b = get_tensor(string_format(TN_MM_GATE, "bias"), false);
+                    model.mm_ffn_down_w = get_tensor(string_format(TN_MM_DOWN, "weight"));
+                    model.mm_ffn_down_b = get_tensor(string_format(TN_MM_DOWN, "bias"), false);
+                    model.mm_post_norm_w = get_tensor(string_format(TN_MM_POST_NORM, "weight"));
+                    model.mm_post_norm_b = get_tensor(string_format(TN_MM_POST_NORM, "bias"), false);
+                    model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"));
+                    model.mm_patch_merger_b = get_tensor(string_format(TN_MM_PATCH_MERGER, "bias"));
+                } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
                     model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);

@@ -1498,8 +1522,8 @@ struct clip_model_loader {
                     // [IMG_BREAK] token embedding
                     model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
                     // for mistral small 3.1
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
-                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                    model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false);
                 } break;
             case PROJECTOR_TYPE_LIGHTONOCR:
                 {

@@ -1507,8 +1531,8 @@ struct clip_model_loader {
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
-                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                    model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false);
                 } break;
             case PROJECTOR_TYPE_ULTRAVOX:
                 {

@@ -1873,6 +1897,8 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
         if (ctx_params.warmup) {
             loader.warmup(*ctx_vision);
         }
+
+        // clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f);
     }
 
     if (loader.has_audio) {

@@ -2582,6 +2608,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
             {
                 GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 clip_image_u8 resized;

@@ -2824,16 +2851,30 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
 int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
     const int n_total = clip_n_output_tokens(ctx, img);
-    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->nx / (params.patch_size * 2);
+    const auto & proj = ctx->proj_type();
+    switch (proj) {
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
+            return (img->nx / params.patch_size) / 2;
+        default:
+            break;
     }
     return n_total;
 }
 
 int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
-    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->ny / (params.patch_size * 2);
+    const auto & proj = ctx->proj_type();
+    switch (proj) {
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
+            return (img->ny / params.patch_size) / 2;
+        default:
+            break;
     }
     return 1;
 }

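With the 2x2 merger, each output token covers a 2-patch by 2-patch region, so the grid returned above is (nx / patch_size) / 2 by (ny / patch_size) / 2. A quick check in Python (patch size 14 is assumed here for illustration):

patch_size, n_merge = 14, 2
nx, ny = 1120, 784                        # preprocessed image size, multiple of patch_size * n_merge
tokens_x = (nx // patch_size) // n_merge  # 40
tokens_y = (ny // patch_size) // n_merge  # 28
print(tokens_x, tokens_y, tokens_x * tokens_y)  # 40 28 1120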
@@ -2890,6 +2931,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
             {
                 // dynamic size (2 conv, so double patch size)
                 int x_patch = img->nx / (params.patch_size * 2);

@@ -3137,6 +3179,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
             {
                 const int merge_ratio = hparams.n_merge;
                 const int pw = image_size_width / patch_size;

@@ -3363,7 +3406,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // copy the embeddings to the location passed by the user
-    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+    if (vec != nullptr) {
+        ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+    }
 
     return true;
 }

@@ -3411,6 +3456,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
             return ctx->model.mm_4h_to_h_w->ne[1];
+        case PROJECTOR_TYPE_GLM4V:
+            return ctx->model.mm_ffn_down_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }

@@ -3427,10 +3474,11 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
 }
 
-bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
+bool clip_is_mrope(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
         || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL
+        || ctx->proj_type() == PROJECTOR_TYPE_GLM4V;
 }
 
 bool clip_is_llava(const struct clip_ctx * ctx) {

@@ -3491,3 +3539,22 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
 const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
     return &ctx->model.hparams;
 }
+
+//
+// API for debugging
+//
+
+void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) {
+    clip_image_f32 img;
+    img.nx = w;
+    img.ny = h;
+    img.buf.resize(h * w * 3);
+    for (int i = 0; i < h * w * 3; i++) {
+        img.buf[i] = static_cast<float>(fill_value);
+    }
+    bool cur_debug_graph = ctx->debug_graph;
+    ctx->debug_graph = true;
+    clip_image_encode(ctx, 1, &img, nullptr);
+    ctx->debug_graph = cur_debug_graph;
+    GGML_ASSERT(img.buf.empty() && "expected, always stop here");
+}

@@ -104,7 +104,7 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
 
 int clip_is_minicpmv(const struct clip_ctx * ctx);
 bool clip_is_glm(const struct clip_ctx * ctx);
-bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+bool clip_is_mrope(const struct clip_ctx * ctx);
 bool clip_is_llava(const struct clip_ctx * ctx);
 bool clip_is_gemma3(const struct clip_ctx * ctx);

@@ -0,0 +1,120 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_glm4v::build() {
+    GGML_ASSERT(model.patch_bias != nullptr);
+    GGML_ASSERT(model.position_embeddings != nullptr);
+    GGML_ASSERT(model.class_embedding == nullptr);
+
+    const int batch_size = 1;
+
+    norm_type norm_t = NORM_TYPE_RMS;
+
+    ggml_tensor * inp_raw = build_inp_raw();
+    ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+    GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+    // second conv dimension
+    {
+        auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_add(ctx0, inp, inp_1);
+
+        inp = ggml_permute(ctx0, inp, 1, 2, 0, 3);  // [w, h, c, b] -> [c, w, h, b]
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_3d(
+            ctx0, inp,
+            n_embd, n_patches_x * n_patches_y, batch_size);
+    }
+
+    // add patch bias
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+    cb(inp, "patch_bias", -1);
+
+    // pos-conv norm
+    inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);
+
+    // calculate absolute position embedding and apply
+    ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
+    learned_pos_embd = ggml_cont_4d(
+        ctx0, learned_pos_embd,
+        n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+    learned_pos_embd = ggml_reshape_4d(
+        ctx0, learned_pos_embd,
+        n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+    learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
+    learned_pos_embd = ggml_cont_3d(
+        ctx0, learned_pos_embd,
+        n_embd, n_patches_x * n_patches_y, batch_size);
+    cb(learned_pos_embd, "learned_pos_embd", -1);
+
+    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+        return ggml_rope_multi(
+            ctx0, cur, positions, nullptr,
+            d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
+            32768, hparams.rope_theta, 1, 0, 1, 32, 1);
+    };
+
+    ggml_tensor * cur = build_vit(
+        inp, n_patches,
+        norm_t,
+        hparams.ffn_op,
+        learned_pos_embd,
+        add_pos);
+
+    cb(cur, "vit_out", -1);
+    // cb(ggml_sum(ctx0, cur), "vit_out_sum", -1);
+
+    // GLM4V projector
+    // ref: https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/models/glm4v/modeling_glm4v.py#L116-L130
+
+    // patch merger (downsample)
+    {
+        int n_merge = hparams.n_merge;
+        GGML_ASSERT(n_merge > 0);
+
+        int n_token_out = n_patches / n_merge / n_merge;
+        cur = ggml_reshape_4d(ctx0, cur, n_embd, n_merge, n_merge, n_token_out);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));  // [n_merge, n_merge, n_embd, n_token_out]
+        cur = ggml_conv_2d(ctx0, model.mm_patch_merger_w, cur, n_merge, n_merge, 0, 0, 1, 1);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], n_token_out);  // [n_embd_out, n_token_out]
+
+        cur = ggml_add(ctx0, cur, model.mm_patch_merger_b);
+    }
+
+    // FC projector
+    {
+        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        // default LayerNorm (post_projection_norm)
+        cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
+        cur = ggml_gelu_erf(ctx0, cur);
+        cb(cur, "after_fc_proj", -1);
+    }
+
+    // FFN projector
+    {
+        cur = build_ffn(cur,
+            model.mm_ffn_up_w, model.mm_ffn_up_b,
+            model.mm_ffn_gate_w, model.mm_ffn_gate_b,
+            model.mm_ffn_down_w, model.mm_ffn_down_b,
+            hparams.ffn_op, -1);
+        cb(cur, "after_ffn_proj", -1);
+        // cb(ggml_sum(ctx0, cur), "merged_sum", -1);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}

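The downsample step in the new graph above is a strided regrouping: the n_patches embeddings are viewed as n_token_out blocks of 2x2 patches and fed through a 2x2 convolution. A torch sketch of the equivalent shape bookkeeping (toy sizes, not the ggml graph):

import torch

n_embd, n_merge, n_token_out, n_embd_out = 8, 2, 3, 16    # toy sizes
x = torch.randn(n_token_out * n_merge * n_merge, n_embd)  # ViT output tokens
w = torch.randn(n_embd_out, n_embd, n_merge, n_merge)     # merger kernel

blocks = x.reshape(n_token_out, n_merge, n_merge, n_embd).permute(0, 3, 1, 2)
merged = torch.nn.functional.conv2d(blocks, w).reshape(n_token_out, n_embd_out)
print(merged.shape)  # torch.Size([3, 16])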
@@ -56,3 +56,8 @@ struct clip_graph_whisper_enc : clip_graph {
     clip_graph_whisper_enc(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
 };
+
+struct clip_graph_glm4v : clip_graph {
+    clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};

@@ -217,7 +217,7 @@ struct mtmd_context {
 
     void init_vision() {
         GGML_ASSERT(ctx_v != nullptr);
-        use_mrope = clip_is_qwen2vl(ctx_v);
+        use_mrope = clip_is_mrope(ctx_v);
 
         projector_type proj = clip_get_projector_type(ctx_v);
         int minicpmv_version = clip_is_minicpmv(ctx_v);

@@ -309,6 +309,10 @@ struct mtmd_context {
             img_beg = "<|image_start|>";
             img_end = "<|image_end|>";
 
+        } else if (proj == PROJECTOR_TYPE_GLM4V) {
+            img_beg = "<|begin_of_image|>";
+            img_end = "<|end_of_image|>";
+
         }
     }