Add VaetkiVisionModel mmproj converter with Rice ViT support

suhyun-hwang 2026-01-11 01:06:22 +09:00
parent 96294c6ad9
commit 025ce711b6
3 changed files with 82 additions and 0 deletions


@@ -7909,6 +7909,86 @@ class VaetkiModel(TextModel):
                raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("VaetkiVisionModel", "VaetkiVLForCausalLM")
class VaetkiVisionModel(MmprojModel):
"""VAETKI Vision Model (mmproj) - Rice ViT with CLS token and 2D RoPE"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# Remap vision config parameters to standard names
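        # (the MmprojModel base expects the standard HF keys, while the Rice ViT
        # config ships num_heads / depth / embed_dim instead)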
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
if "embed_dim" in self.hparams_vision:
self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim")
if "image_size" not in self.hparams_vision:
self.hparams_vision["image_size"] = self.preprocessor_config.get("size", {}).get("shortest_edge", 560)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        assert self.hparams_vision is not None
        hparams = self.hparams_vision
        # VAETKI projector type - routes to vaetki.cpp graph builder
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VAETKI)
        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-5))
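        # a merge size of 2 folds each 2x2 patch grid into one token before the projector MLP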
        self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2))

    def tensor_force_quant(self, name, new_name, bid, n_dims):
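        # new_name is the mapped GGUF tensor name; keep the CLS position
        # embedding in F32 so it is never quantized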
if "class_pos_embd" in new_name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        # Only process vision tensors
        if not (name.startswith("model.visual.") or name.startswith("visual.")):
            return []

        # Handle merger tensors with special index mapping
        # clip.cpp PROJECTOR_TYPE_VAETKI expects:
        #   mm.model.mlp.0.* -> ln_q (pre-norm)
        #   mm.model.mlp.1.* -> mlp.0 (up projection)
        #   mm.model.mlp.3.* -> mlp.2 (down projection)
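        #   (mm.model.mlp.2 is intentionally absent; it corresponds to the
        #   weightless activation between the two projections)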
if "merger.ln_q" in name:
# ln_q -> mm.model.mlp.0 (used as norm in vaetki.cpp)
suffix = "weight" if name.endswith(".weight") else "bias"
return [(f"mm.model.mlp.0.{suffix}", data_torch)]
elif "merger.mlp.0" in name:
# mlp.0 -> mm.model.mlp.1 (up projection)
suffix = "weight" if name.endswith(".weight") else "bias"
return [(f"mm.model.mlp.1.{suffix}", data_torch)]
elif "merger.mlp.2" in name:
# mlp.2 -> mm.model.mlp.3 (down projection)
suffix = "weight" if name.endswith(".weight") else "bias"
return [(f"mm.model.mlp.3.{suffix}", data_torch)]
# Handle class_embedding and class_pos_emb (keep model.visual. prefix for mapping)
if "class_embedding" in name or "class_pos_emb" in name:
return [(self.map_tensor_name(name), data_torch)]
# Strip model.visual. -> visual. for other tensors
if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.")
# Split fused QKV tensors
if ".qkv." in name:
            if data_torch.ndim == 2:
                c3, _ = data_torch.shape
            else:
                c3 = data_torch.shape[0]
            assert c3 % 3 == 0
            c = c3 // 3
            return [
                (self.map_tensor_name(name.replace("qkv", "q")), data_torch[:c]),
                (self.map_tensor_name(name.replace("qkv", "k")), data_torch[c:c*2]),
                (self.map_tensor_name(name.replace("qkv", "v")), data_torch[c*2:]),
            ]
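
        # everything else falls through to the standard tensor name mapping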
        return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.MINIMAXM2

@@ -3650,6 +3650,7 @@ class VisionProjectorType:
    MUSIC_FLAMINGO = "musicflamingo"  # audio
    GLM4V = "glm4v"
    YOUTUVL = "youtuvl"
    VAETKI = "vaetki"


# Items here are (block size, type size)


@@ -1471,6 +1471,7 @@ class TensorNameMap:
            "vision_tower.ln_pre",        # pixtral-hf
            "vision_encoder.ln_pre",      # pixtral
            "vision_model.layernorm_pre", # llama4
            "visual.pre_layernorm",       # vaetki
        ),

        MODEL_TENSOR.V_POST_NORM: (