diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index eb43520f98..0defed9aa8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7562,6 +7562,237 @@ class DeepseekV2Model(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("VaetkiForCausalLM", "VaetkiVLForCausalLM") +class VaetkiModel(TextModel): + """VAETKI MoE model with MLA attention and 4-norm layer structure""" + model_arch = gguf.MODEL_ARCH.VAETKI + + _experts: list[dict[str, Tensor]] | None = None + + def set_vocab(self): + # VAETKI: hybrid tokenizer with SPM-style ▁ space markers + BPE rank-based merges + <0xXX> byte fallback + # manual token loading because VAETKI doesn't fit standard BPE or SPM vocab loading + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + + tokens: list[str] = [] + toktypes: list[int] = [] + + reverse_vocab = {id_: tok for tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + added_tokens_decoder = tokenizer.added_tokens_decoder + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token: str = reverse_vocab[i] + if token in added_vocab: + if not added_tokens_decoder[i].normalized: + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + if added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) + + self.gguf_writer.add_tokenizer_model("vaetki") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + hparams = self.hparams + + # For MLA without absorption, n_head_kv = n_head (full MHA after decompression) + self.gguf_writer.add_head_count_kv(hparams["num_attention_heads"]) + + # MLA parameters (like DeepSeek2) + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + + # For MLA without absorption, key_length/value_length are the full (MHA) dimensions + # key = qk_nope + qk_rope, value = v_head_dim + self.gguf_writer.add_key_length(hparams["qk_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + + # key_length_mla/value_length_mla are the MLA head dimensions (same as key/value for non-absorption) + self.gguf_writer.add_key_length_mla(hparams["qk_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + # MoE parameters + self.gguf_writer.add_leading_dense_block_count(hparams.get("first_k_dense_replace", 1)) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams.get("n_shared_experts", 1)) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + if (routed_scale := hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scale) + if hparams.get("norm_topk_prob", False): + self.gguf_writer.add_expert_weights_norm(True) + + 
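For readers unfamiliar with the "MLA without absorption" remarks above, a rough back-of-the-envelope sketch of what the full-MHA cache layout implies. The numeric hparams below are hypothetical DeepSeek-V2-style values chosen for illustration, not VAETKI's actual configuration:

```python
# Hypothetical MLA dimensions (DeepSeek-V2-style), used only for illustration.
n_head           = 16
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim       = 128
kv_lora_rank     = 512

# key_length written by the converter: the full decompressed per-head key width
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim          # 192

# Without absorption the KV cache stores full MHA-sized K and V heads, so
# n_head_kv == n_head and the per-token cache width for one layer is:
full_mha_width = n_head * (qk_head_dim + v_head_dim)        # 5120 values

# An absorbed-MLA cache would instead hold only the compressed latent plus the
# shared rope key part:
absorbed_width = kv_lora_rank + qk_rope_head_dim            # 576 values

print(full_mha_width, absorbed_width)
```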
self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + sliding_window_pattern = [] + for t in self.hparams["layer_types"]: + sliding_window_pattern.append(t == "sliding_attention") + self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision encoder tensors + if "vision_tower" in name or "vision_model" in name or "visual" in name: + return [] + if name.startswith("model.vision_model.") or name.startswith("vision_model."): + return [] + + # Remove language_model prefix + if name.startswith("model.language_model."): + name = name.replace("model.language_model.", "model.") + elif name.startswith("language_model."): + name = name.replace("language_model.", "model.") + + if name.endswith("q_b_proj.weight"): + n_head = self.hparams["num_attention_heads"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + qk_rope_head_dim = self.hparams["qk_rope_head_dim"] + qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + data_torch = data_torch.view(n_head, qk_head_dim, -1) + data_torch = torch.cat([data_torch[:, qk_nope_head_dim:, :], data_torch[:, :qk_nope_head_dim, :]], dim=1) + data_torch = data_torch.reshape(n_head * qk_head_dim, -1) + + # VAETKI WBLRMSNorm: add 1 to weights for standard RMSNorm compatibility + norm_weight_patterns = [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + "pre_mlp_layernorm.weight", + "post_mlp_layernorm.weight", + "q_a_layernorm.weight", + "kv_a_layernorm.weight", + "model.norm.weight", + ] + if any(pattern in name for pattern in norm_weight_patterns): + data_torch = data_torch + 1.0 + + # Handle MoE expert tensors + if ".mlp.experts." in name and ".shared_experts." not in name: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + # Check if all experts for this layer are collected (n_experts * 3 tensors: down/gate/up) + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # Merge experts into 3D tensors + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # Check for unprocessed experts + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("VaetkiVisionModel", "VaetkiVLForCausalLM") +class VaetkiVisionModel(MmprojModel): + """VAETKI Vision Model (mmproj) - Rice ViT with CLS token and 2D RoPE""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + # Remap vision config parameters to standard names + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + if "embed_dim" in 
self.hparams_vision: + self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim") + if "image_size" not in self.hparams_vision: + self.hparams_vision["image_size"] = 560 # unused, set for compatibility + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_vision is not None + hparams = self.hparams_vision + + # VAETKI projector type - routes to vaetki.cpp graph builder + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VAETKI) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-5)) + self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2)) + + # support dynamic size + self.gguf_writer.add_vision_image_min_pixels(self.preprocessor_config["min_pixels"]) + self.gguf_writer.add_vision_image_max_pixels(self.preprocessor_config["max_pixels"]) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + if "class_pos_embd" in new_name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Only process vision tensors + if not (name.startswith("model.visual.") or name.startswith("visual.")): + return [] + + # Handle merger tensors with special index mapping + # clip.cpp PROJECTOR_TYPE_VAETKI expects: + # mm.input_norm.* -> ln_q (pre-norm) + # mm.up.* -> mlp.0 (up projection) + # mm.down.* -> mlp.2 (down projection) + if "merger.ln_q" in name: + suffix = ".weight" if name.endswith(".weight") else ".bias" + return [(self.format_tensor_name(gguf.MODEL_TENSOR.V_MM_INP_NORM, suffix=suffix), data_torch)] + elif "merger.mlp.0" in name: + suffix = ".weight" if name.endswith(".weight") else ".bias" + return [(self.format_tensor_name(gguf.MODEL_TENSOR.V_MM_UP, suffix=suffix), data_torch)] + elif "merger.mlp.2" in name: + suffix = ".weight" if name.endswith(".weight") else ".bias" + return [(self.format_tensor_name(gguf.MODEL_TENSOR.V_MM_DOWN, suffix=suffix), data_torch)] + + # Handle class_embedding and class_pos_emb (keep model.visual. prefix for mapping) + if "class_embedding" in name or "class_pos_emb" in name: + return [(self.map_tensor_name(name), data_torch)] + + # Strip model.visual. -> visual. 
for other tensors + if name.startswith("model.visual."): + name = name.replace("model.visual.", "visual.") + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register("MiniMaxM2ForCausalLM") class MiniMaxM2Model(TextModel): model_arch = gguf.MODEL_ARCH.MINIMAXM2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 31273b2b5a..d8db79db19 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -284,6 +284,8 @@ class Keys: class ClipVision: PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" @@ -459,6 +461,7 @@ class MODEL_ARCH(IntEnum): MIMO2 = auto() LLAMA_EMBED = auto() MAINCODER = auto() + VAETKI = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -655,6 +658,7 @@ class MODEL_TENSOR(IntEnum): V_MMPROJ_MLP = auto() V_MMPROJ_PEG = auto() V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_CLS_POS = auto() V_ENC_EMBD_PATCH = auto() V_ENC_EMBD_NORM = auto() V_ENC_EMBD_POS = auto() @@ -880,6 +884,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.LLAMA_EMBED: "llama-embed", MODEL_ARCH.MAINCODER: "maincoder", + MODEL_ARCH.VAETKI: "vaetki", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1073,6 +1078,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}", MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}", MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd", + MODEL_TENSOR.V_ENC_EMBD_CLS_POS: "v.class_pos_embd", MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd", MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd", MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd", @@ -1191,6 +1197,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_MMPROJ_MLP, MODEL_TENSOR.V_MMPROJ_PEG, MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_CLS_POS, MODEL_TENSOR.V_ENC_EMBD_PATCH, MODEL_TENSOR.V_ENC_EMBD_NORM, MODEL_TENSOR.V_ENC_EMBD_POS, @@ -3377,6 +3384,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.VAETKI: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -3617,6 +3652,7 @@ class VisionProjectorType: MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" YOUTUVL = "youtuvl" + VAETKI = "vaetki" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7fbb78866b..0fe91786aa 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1113,6 +1113,12 @@ class GGUFWriter: def add_vision_image_size(self, value: int) -> None: 
self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value) + def add_vision_image_max_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value) + + def add_vision_image_min_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value) + def add_vision_preproc_image_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 84aa868809..3f121fc494 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1281,6 +1281,11 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "model.visual.class_embedding", # vaetki + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS_POS: ( + "model.visual.class_pos_emb", # vaetki ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1466,6 +1471,7 @@ class TensorNameMap: "vision_tower.ln_pre", # pixtral-hf "vision_encoder.ln_pre", # pixtral "vision_model.layernorm_pre", # llama4 + "visual.pre_layernorm", # vaetki ), MODEL_TENSOR.V_POST_NORM: ( diff --git a/include/llama.h b/include/llama.h index bf4e28a8be..b86790f3e0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -76,6 +76,7 @@ extern "C" { LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming + LLAMA_VOCAB_TYPE_VAETKI = 7, // VAETKI tokenizer based on rank-based BPE with SPM-style space markers }; enum llama_rope_type { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f337afd6b3..50f50fa758 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -137,6 +137,7 @@ add_library(llama models/t5-dec.cpp models/t5-enc.cpp models/wavtokenizer-dec.cpp + models/vaetki.cpp models/xverse.cpp models/mistral3.cpp models/graph-context-mamba.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a54bc1956a..040dd8f529 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -120,6 +120,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_VAETKI, "vaetki" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -339,6 +340,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, + { LLM_TENSOR_FFN_PRE_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, @@ -2289,6 +2291,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, }; + case LLM_ARCH_VAETKI: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + 
LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } diff --git a/src/llama-arch.h b/src/llama-arch.h index 270d28b16a..f784558dfc 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -124,6 +124,7 @@ enum llm_arch { LLM_ARCH_MIMO2, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, + LLM_ARCH_VAETKI, LLM_ARCH_UNKNOWN, }; @@ -345,6 +346,7 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_PRE_NORM, LLM_TENSOR_FFN_POST_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 72490a89b5..f87b6b1b08 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -127,6 +127,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; + case LLM_TYPE_100B_A10B: return "100B.A10B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; @@ -1129,6 +1130,39 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_VAETKI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla); + hparams.n_embd_head_k = hparams.n_embd_head_k_mla; + hparams.n_embd_head_v = hparams.n_embd_head_v_mla; + hparams.n_head_kv_arr = hparams.n_head_arr; + + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_7B_A1B; break; + case 48: type = LLM_TYPE_100B_A10B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VL: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); @@ -6971,6 +7005,64 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_VAETKI: + { + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla; + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla; + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = 
hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + const int64_t n_layer_dense = hparams.n_layer_dense_lead; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_qk_rope)}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + if (i < n_layer_dense) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7324,6 +7416,18 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } + if (arch == LLM_ARCH_VAETKI) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); + 
LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } + if (arch == LLM_ARCH_QWEN2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); @@ -7645,6 +7749,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_VAETKI: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params); @@ -8268,6 +8376,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_VAETKI: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index d1de16e3f2..604cb22d9a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -120,6 +120,7 @@ enum llm_type { LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, + LLM_TYPE_100B_A10B, // VAETKI LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 @@ -261,6 +262,7 @@ struct llama_layer { struct ggml_tensor * ffn_norm = nullptr; struct ggml_tensor * ffn_norm_b = nullptr; struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * ffn_pre_norm = nullptr; struct ggml_tensor * layer_out_norm = nullptr; struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a23950d007..010574ace8 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1518,6 +1518,206 @@ private: const llm_tokenizer_plamo2 & tokenizer; }; +// +// VAETKI tokenizer +// Hybrid tokenizer: SPM-style ▁ space markers + BPE rank-based merges + <0xXX> byte fallback +// + +struct llm_tokenizer_vaetki : llm_tokenizer { + llm_tokenizer_vaetki(const llama_vocab & vocab) { + GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_VAETKI); + } +}; + +struct llm_tokenizer_vaetki_session { + llm_tokenizer_vaetki_session(const llama_vocab & vocab) + : vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + + // Normalize - replace all spaces with ▁ (U+2581) + std::string normalized; + normalized.reserve(text.size() * 3); + for (size_t i = 0; i < text.size(); ) { + if (text[i] == ' ') { + normalized += "\xE2\x96\x81"; + i++; + } else { + size_t char_len = unicode_len_utf8(text[i]); + normalized += text.substr(i, char_len); + i += char_len; + } + } + + // Split on ▁ boundaries, keeping ▁ with following text + // "Hello▁World" -> ["Hello", "▁World"] + // "Hello▁▁World" -> ["Hello", "▁▁World"] + std::vector word_collection; + std::string current_word; + const std::string escaped_space = "\xE2\x96\x81"; // ▁ (U+2581) + + for (size_t i = 0; i < normalized.size(); ) { + size_t char_len = unicode_len_utf8(normalized[i]); + + if (char_len == 3 && + i + 2 < normalized.size() && + (unsigned char)normalized[i] == 0xE2 && + (unsigned char)normalized[i+1] 
== 0x96 && + (unsigned char)normalized[i+2] == 0x81) { + if (!current_word.empty()) { + word_collection.push_back(current_word); + current_word.clear(); + } + current_word = escaped_space; + i += 3; + } else { + current_word += normalized.substr(i, char_len); + i += char_len; + } + } + if (!current_word.empty()) { + word_collection.push_back(current_word); + } + + symbols_final.clear(); + + for (const auto & word : word_collection) { + work_queue = llm_bigram_bpe::queue(); + symbols.clear(); + + int index = 0; + size_t offset = 0; + + // Check if word exists as a single token (ignore_merges behavior) + if (vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } + + while (offset < word.size()) { + llm_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset])); + sym.text = word.c_str() + offset; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + for (int i = 1; i < (int) symbols.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue.empty()) { + auto bigram = work_queue.pop_move(); + + auto & left_symbol = symbols[bigram.left]; + auto & right_symbol = symbols[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the finished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols = symbols_final; + + if (!symbols.empty()) { + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + if (symbol.n == 0) { + continue; + } + + const std::string str = std::string(symbol.text, symbol.n); + const auto token = vocab.text_to_token(str); + + if (token == LLAMA_TOKEN_NULL) { + // Byte fallback: use <0xXX> format + for (auto j = str.begin(); j != str.end(); ++j) { + char buf[8]; + snprintf(buf, sizeof(buf), "<0x%02X>", static_cast(*j)); + std::string byte_str(buf); + auto token_byte = vocab.text_to_token(byte_str); + if (token_byte != LLAMA_TOKEN_NULL) { + output.push_back(token_byte); + } + } + } else { + output.push_back(token); + } + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + std::string left_token = std::string(symbols[left].text, symbols[left].n); + std::string right_token = std::string(symbols[right].text, symbols[right].n); + + int rank_found = -1; + + rank_found = vocab.find_bpe_rank(left_token, 
right_token); + + if (rank_found < 0) { + return; + } + + llm_bigram_bpe bigram; + + bigram.left = left; + bigram.right = right; + bigram.text = left_token + right_token; + bigram.size = left_token.size() + right_token.size(); + bigram.rank = rank_found; + + work_queue.push(bigram); + } + + const llama_vocab & vocab; + + std::vector symbols; + std::vector symbols_final; + llm_bigram_bpe::queue work_queue; +}; + // // impl // @@ -1831,6 +2031,39 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "vaetki") { + type = LLAMA_VOCAB_TYPE_VAETKI; + + // read bpe merges and populate bpe ranks (same as gpt2) + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // VAETKI default special tokens (will be overridden by model config) + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -2626,6 +2859,7 @@ std::string llama_vocab::impl::type_name() const{ case LLAMA_VOCAB_TYPE_UGM: return "UGM"; case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2"; + case LLAMA_VOCAB_TYPE_VAETKI: return "VAETKI"; default: return "unknown"; } } @@ -2670,7 +2904,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const { const auto & token_data = id_to_token.at(id); switch (get_type()) { case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { + case LLAMA_VOCAB_TYPE_UGM: + case LLAMA_VOCAB_TYPE_VAETKI: { + // <0xXX> format auto buf = token_data.text.substr(3, 2); return strtol(buf.c_str(), NULL, 16); } @@ -2712,6 +2948,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { case LLAMA_VOCAB_TYPE_PLAMO2: tokenizer = std::make_unique(vocab); break; + case LLAMA_VOCAB_TYPE_VAETKI: + tokenizer = std::make_unique(vocab); + break; default: GGML_ABORT("unsupported vocab type"); } @@ -3071,6 +3310,41 @@ std::vector llama_vocab::impl::tokenize( } } } break; + case LLAMA_VOCAB_TYPE_VAETKI: + { + llm_tokenizer_vaetki_session session(vocab); + + if (add_special && add_bos) { + GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL); + output.push_back(special_bos_id); + } + + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); +#endif + + session.tokenize(text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + + if (add_special && add_bos && 
output.size() >= 2 && output[1] == special_bos_id) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + + if (add_special && add_eos) { + GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL); + output.push_back(special_eos_id); + } + } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ABORT("fatal error"); } @@ -3119,7 +3393,8 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t switch (get_type()) { case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { + case LLAMA_VOCAB_TYPE_UGM: + case LLAMA_VOCAB_TYPE_VAETKI: { // NOTE: we accept all unsupported token types, // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { @@ -3420,8 +3695,9 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const { case LLAMA_VOCAB_TYPE_BPE: { return pimpl->token_to_id.at(unicode_byte_to_utf8(ch)); } - case LLAMA_VOCAB_TYPE_PLAMO2: { - // PLaMo-2 uses byte tokens in format <0xXX> + case LLAMA_VOCAB_TYPE_PLAMO2: + case LLAMA_VOCAB_TYPE_VAETKI: { + // PLaMo-2/VAETKI uses byte tokens in format <0xXX> char hex_str[8]; snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch); return pimpl->token_to_id.at(hex_str); diff --git a/src/models/models.h b/src/models/models.h index 3a44f7f140..e5ce8141e4 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -568,6 +568,10 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_vaetki : public llm_graph_context { + llm_build_vaetki(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_xverse : public llm_graph_context { llm_build_xverse(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/vaetki.cpp b/src/models/vaetki.cpp new file mode 100644 index 0000000000..8ddd39821e --- /dev/null +++ b/src/models/vaetki.cpp @@ -0,0 +1,182 @@ +#include "models.h" + +llm_build_vaetki::llm_build_vaetki(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla; + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_qk_nope + n_embd_head_qk_rope)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q_a", il); + + q = build_norm(q, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); + cb(q, "q_a_norm", il); + + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", 
il); + + // q is now [rope | nope] after weight reordering in conversion + // reshape to {n_embd_head_k_mla, n_head, n_tokens} + q = ggml_reshape_3d(ctx0, q, n_embd_head_k_mla, n_head, n_tokens); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cmpr, "kv_cmpr", il); + + // {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + // apply rope - rotates first n_rot dims, copies rest unchanged + ggml_tensor * Qcur = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(k_pe, "k_pe_rope", il); + + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr_norm", il); + + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head, 0); + cb(k_nope, "k_nope", il); + + // {n_embd_head_v_mla, n_head, n_tokens} + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, + n_embd_head_v_mla, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(Vcur, "Vcur", il); + + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "Vcur_cont", il); + + ggml_tensor * q_pe_ref = ggml_view_3d(ctx0, Qcur, + n_embd_head_qk_rope, n_head, n_tokens, + Qcur->nb[1], Qcur->nb[2], 0); + ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe_ref), k_nope, 0); + cb(Kcur, "Kcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + 
model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af32..4163058d4c 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -26,6 +26,7 @@ add_library(mtmd models/qwen2vl.cpp models/qwen3vl.cpp models/siglip.cpp + models/vaetki.cpp models/whisper-enc.cpp models/mobilenetv5.cpp models/youtuvl.cpp diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index dd693623a2..74dfa2f05a 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -36,6 +36,8 @@ // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels" +#define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" @@ -66,6 +68,7 @@ #define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" +#define TN_CLASS_POS_EMBD "v.class_pos_embd" #define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" #define TN_PATCH_BIAS "v.patch_embd.bias" @@ -233,6 +236,7 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_VAETKI, PROJECTOR_TYPE_UNKNOWN, }; @@ -266,6 +270,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_VAETKI, "vaetki"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index d4ff9151bb..7d7af1b8e5 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -219,6 +219,7 @@ struct clip_model { // embeddings ggml_tensor * class_embedding = nullptr; + ggml_tensor * class_pos_emb = nullptr; ggml_tensor * patch_embeddings_0 = nullptr; ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) ggml_tensor * patch_bias = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9fa5afc390..4b17fcadd5 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -849,6 +849,10 @@ static 
ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_VAETKI: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1194,6 +1198,15 @@ struct clip_model_loader { hparams.set_limit_image_tokens(8, 4096); hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_VAETKI: + { + hparams.rope_theta = 10000.0f; + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); + hparams.set_warmup_n_tokens(40*40); + } break; case PROJECTOR_TYPE_LLAMA4: { hparams.rope_theta = 10000.0f; @@ -1542,6 +1555,16 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_VAETKI: + { + model.class_pos_emb = get_tensor(TN_CLASS_POS_EMBD); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); + model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); + model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight")); + model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias")); + model.mm_ffn_down_w = get_tensor(string_format(TN_MM_DOWN, "weight")); + model.mm_ffn_down_b = get_tensor(string_format(TN_MM_DOWN, "bias")); + } break; case PROJECTOR_TYPE_GLM4V: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -2834,6 +2857,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_VAETKI: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); clip_image_u8 resized; @@ -3234,6 +3258,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_LLAMA4: + case PROJECTOR_TYPE_VAETKI: { // both X and Y are downscaled by the scale factor int scale_factor = ctx->model.hparams.n_merge; @@ -3481,11 +3506,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_VAETKI: { const int merge_ratio = hparams.n_merge; const int pw = image_size_width / patch_size; const int ph = image_size_height / patch_size; - std::vector positions(n_pos * 4); + + const int pos_size = num_patches; + std::vector positions(pos_size * 4); int ptr = 0; for (int y = 0; y < ph; y += merge_ratio) { for (int x = 0; x < pw; x += merge_ratio) { @@ -3775,6 +3803,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_4h_to_h_w->ne[1]; case PROJECTOR_TYPE_LFM2A: return ctx->model.position_embeddings->ne[0]; + case PROJECTOR_TYPE_VAETKI: case PROJECTOR_TYPE_GLM4V: return ctx->model.mm_ffn_down_w->ne[1]; default: diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7b..3147c6495b 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -109,3 +109,8 @@ struct clip_graph_mobilenetv5 : clip_graph { ggml_tensor * inp, const mobilenetv5_block & block); }; + +struct clip_graph_vaetki : clip_graph { + clip_graph_vaetki(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; diff --git 
a/tools/mtmd/models/vaetki.cpp b/tools/mtmd/models/vaetki.cpp new file mode 100644 index 0000000000..ee23cb01b0 --- /dev/null +++ b/tools/mtmd/models/vaetki.cpp @@ -0,0 +1,104 @@ +#include "models.h" + +ggml_cgraph * clip_graph_vaetki::build() { + GGML_ASSERT(model.class_embedding != nullptr); + + const int batch_size = 1; + const int n_pos = n_patches + 1; + const int n_pos_patches = n_patches; + const int num_position_ids = n_pos_patches * 4; + + norm_type norm_t = NORM_TYPE_NORMAL; + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp = build_inp(); + + // add CLS token + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); + cb(inp, "inp_with_cls", -1); + + // position IDs for 2D RoPE (patch tokens only) + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // precompute CLS position embedding cos/sin + ggml_tensor * cls_cos = nullptr; + ggml_tensor * cls_sin = nullptr; + if (model.class_pos_emb) { + ggml_tensor * cls_pos = ggml_concat(ctx0, model.class_pos_emb, model.class_pos_emb, 0); + cls_cos = ggml_cos(ctx0, cls_pos); + cls_sin = ggml_sin(ctx0, cls_pos); + } + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) -> ggml_tensor * { + // split CLS and patch tokens + // use cur->nb[2] to support both fused QKV (nb[2]=3*n_embd) and separate Q/K/V (nb[2]=n_embd) + ggml_tensor * cur_cls = ggml_view_3d(ctx0, cur, d_head, n_head, 1, + ggml_row_size(cur->type, d_head), + cur->nb[2], 0); + ggml_tensor * cur_patch = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos_patches, + ggml_row_size(cur->type, d_head), + cur->nb[2], + cur->nb[2]); + + // apply RoPE to CLS token using class_pos_emb + if (cls_cos && cls_sin) { + ggml_tensor * cls_1 = ggml_view_3d(ctx0, cur_cls, d_head/2, n_head, 1, + ggml_row_size(cur_cls->type, d_head), + ggml_row_size(cur_cls->type, d_head * n_head), 0); + ggml_tensor * cls_2 = ggml_view_3d(ctx0, cur_cls, d_head/2, n_head, 1, + ggml_row_size(cur_cls->type, d_head), + ggml_row_size(cur_cls->type, d_head * n_head), + ggml_row_size(cur_cls->type, d_head/2)); + ggml_tensor * cls_rot = ggml_concat(ctx0, ggml_neg(ctx0, cls_2), cls_1, 0); + + cur_cls = ggml_add(ctx0, + ggml_mul(ctx0, cur_cls, cls_cos), + ggml_mul(ctx0, cls_rot, cls_sin)); + } + + // apply 2D RoPE to patch tokens + cur_patch = ggml_rope_multi(ctx0, cur_patch, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + return ggml_concat(ctx0, cur_cls, cur_patch, 2); + }; + + ggml_tensor * cur = build_vit( + inp, n_pos, + norm_t, + hparams.ffn_op, + nullptr, + add_pos); + + cb(cur, "vit_out", -1); + + // remove CLS token + ggml_tensor * embeddings = ggml_view_2d(ctx0, cur, + n_embd, n_pos_patches, + ggml_row_size(cur->type, n_embd), + ggml_row_size(cur->type, n_embd)); + cb(embeddings, "patches_only", -1); + + // merger + embeddings = build_norm(embeddings, model.mm_input_norm_w, model.mm_input_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); + cb(embeddings, "merger_normed", -1); + + // pixel shuffle + const int scale_factor = hparams.n_merge; + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * scale_factor * scale_factor, n_pos_patches / (scale_factor * scale_factor), batch_size); + cb(embeddings, "merger_reshaped", -1); + + embeddings = build_ffn(embeddings, + model.mm_ffn_up_w, model.mm_ffn_up_b, + nullptr, nullptr, + model.mm_ffn_down_w, model.mm_ffn_down_b, + FFN_GELU, + -1); + cb(embeddings, "merger_out", -1); + + 
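The CLS-token branch of add_pos above applies a rotate-half rotation by hand, with class_pos_emb supplying the per-dimension angles that the graph duplicates across the head. A minimal numpy sketch of that computation, with toy shapes standing in for the real tensors:

```python
import numpy as np

def rotate_half_rope(x, angles):
    """Rotate-half RoPE on the last dimension of x.

    x      : (..., d_head) activations, e.g. the CLS token's per-head q or k
    angles : (d_head // 2,) rotation angles; stands in for class_pos_emb, which
             the graph duplicates (cls_pos = concat(pos, pos)) to cover d_head
    """
    d = x.shape[-1]
    theta = np.concatenate([angles, angles])       # cls_pos
    cos, sin = np.cos(theta), np.sin(theta)        # cls_cos, cls_sin
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    x_rot = np.concatenate([-x2, x1], axis=-1)     # cls_rot = concat(-cls_2, cls_1)
    return x * cos + x_rot * sin                   # cur_cls*cls_cos + cls_rot*cls_sin

# sanity check: a zero angle leaves the vector unchanged
x = np.random.randn(4, 8)                          # 4 heads, d_head = 8
assert np.allclose(rotate_half_rope(x, np.zeros(4)), x)
```

The patch tokens take the regular ggml_rope_multi path instead, with the four mrope sections splitting d_head evenly.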
ggml_build_forward_expand(gf, embeddings); + + return gf; +}
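As a companion to the llm_tokenizer_vaetki_session code in src/llama-vocab.cpp above, here is a simplified Python model of the same flow: spaces become U+2581 markers, each marker stays attached to the text that follows it, the lowest-rank BPE merge is applied repeatedly until none remain, and any piece missing from the vocab falls back to <0xXX> byte tokens. The toy vocab and merge table are made up for illustration, and the greedy rescan replaces the priority-queue bookkeeping of the real implementation:

```python
SPACE = "\u2581"  # ▁

def vaetki_tokenize(text, vocab, merge_ranks):
    # 1) normalize: every space becomes the ▁ marker
    text = text.replace(" ", SPACE)

    # 2) split so each ▁ starts a new word and stays attached to what follows
    words, cur = [], ""
    for ch in text:
        if ch == SPACE:
            if cur:
                words.append(cur)
            cur = SPACE
        else:
            cur += ch
    if cur:
        words.append(cur)

    out = []
    for word in words:
        if word in vocab:            # ignore_merges behavior: whole-word hit wins
            out.append(word)
            continue
        parts = list(word)           # start from single characters
        while True:                  # 3) apply the lowest-rank merge available
            best = None
            for i in range(len(parts) - 1):
                rank = merge_ranks.get((parts[i], parts[i + 1]))
                if rank is not None and (best is None or rank < best[0]):
                    best = (rank, i)
            if best is None:
                break
            i = best[1]
            parts[i:i + 2] = [parts[i] + parts[i + 1]]
        for piece in parts:
            if piece in vocab:
                out.append(piece)
            else:                    # 4) byte fallback to <0xXX> tokens
                out.extend(f"<0x{b:02X}>" for b in piece.encode("utf-8"))
    return out

vocab = {SPACE, SPACE + "world", "he", "llo"}
merges = {("h", "e"): 0, ("l", "lo"): 1, ("l", "o"): 2}
print(vaetki_tokenize("hello world", vocab, merges))   # ['he', 'llo', '▁world']
```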