diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cfa60c4ede..7ad20c0869 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1233,6 +1233,9 @@ class TextModel(ModelBase): if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer res = "kormo" + if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1": + # ref: https://huggingface.co/tencent/Youtu-LLM-2B + res = "youtu" if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91": # ref: https://huggingface.co/upstage/Solar-Open-100B res = "solar-open" @@ -7189,6 +7192,7 @@ class DeepseekModel(TextModel): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "YoutuForCausalLM", ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -7255,7 +7259,15 @@ class DeepseekV2Model(TextModel): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + # first_k_dense_replace: number of leading layers using dense FFN instead of MoE + # For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers + # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers + has_moe = hparams.get("n_routed_experts") is not None + first_k_dense_replace = hparams.get("first_k_dense_replace") + if first_k_dense_replace is None: + # Default: if no MoE, all layers are dense; if MoE, none are dense + first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0 + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) @@ -7267,11 +7279,24 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) - self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) - self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) - self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + # MoE parameters (required by C++ code for DEEPSEEK2 arch) + # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length + moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False) + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + + if (n_routed_experts := hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + + # expert_shared_count is required by C++ code, default to 0 for non-MoE models + n_shared_experts = hparams.get("n_shared_experts", 0) + self.gguf_writer.add_expert_shared_count(n_shared_experts) + + # When not set, C++ code will use scale_w = false to skip the no-op scaling + if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) @@ -7287,10 +7312,17 @@ class DeepseekV2Model(TextModel): # skip vision tensors and remove "language_model." for Kimi-VL if "vision_tower" in name or "multi_modal_projector" in name: return [] - + if name.startswith("siglip2.") or name.startswith("merger."): + return [] if name.startswith("language_model."): name = name.replace("language_model.", "") + # skip lm_head.weight if tie_word_embeddings is True + if self.hparams.get("tie_word_embeddings", False): + if name == "lm_head.weight" or name == "model.lm_head.weight": + logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)") + return [] + # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -10625,6 +10657,59 @@ class JanusProVisionModel(MmprojModel): return [] +@ModelBase.register("YOUTUVLForConditionalGeneration", "YOUTUVLForCausalLM") +class YOUTUVLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6)) + + # Handle activation function + hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower() + if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"): + self.gguf_writer.add_vision_use_gelu(True) + elif hidden_act == "silu": + self.gguf_writer.add_vision_use_silu(True) + else: + raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}") + + self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2)) + + window_size = self.hparams.get("window_size") + if window_size is not None: + self.gguf_writer.add_vision_window_size(window_size) + # fullatt_block_indexes contains explicit layer indices that use full attention + # e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention + # All other layers use window attention + fullatt_block_indexes = self.hparams.get("fullatt_block_indexes") + assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl" + # Store the explicit layer indices for YoutuVL (irregular pattern approach) + self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Skip language model tensors + skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.') + if name.startswith(skip_prefixes): + return [] + + # Try to map the tensor using TensorNameMap (handles vision encoder and projector) + try: + new_name = self.map_tensor_name(name) + return [(new_name, data_torch)] + except ValueError: + # If mapping fails, log warning and skip + logger.warning(f"Cannot map tensor: {name}") + return [] + + @ModelBase.register("SolarOpenForCausalLM") class SolarOpenModel(Glm4MoeModel): model_arch = gguf.MODEL_ARCH.GLM4_MOE diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 243cf8a29b..74c67e6a9c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -145,6 +145,7 @@ models = [ {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, + {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", }, {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c2a0f41c1b..0ac512ff36 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -294,7 +294,9 @@ class Keys: USE_GELU = "clip.use_gelu" USE_SILU = "clip.use_silu" N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" @@ -3494,6 +3496,7 @@ class VisionProjectorType: LFM2A = "lfm2a" # audio MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" + YOUTUVL = "youtuvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a4a504f8d..612a978e4c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1129,11 +1129,40 @@ class GGUFWriter: self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value) def add_vision_n_wa_pattern(self, value: int) -> None: + """Add window attention pattern interval for vision models. + + This defines the pattern interval for window attention vs full attention layers. + For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention, + while other layers use window attention. + + Used by models like Qwen2.5-VL where full attention layers follow a regular pattern. + """ self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None: + """Add explicit layer indexes that use full attention in vision models. + + This specifies the exact layer indices (0-based) that should use full attention + instead of window attention. All other layers will use window attention. + + Args: + layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15]) + + Used by models like YoutuVL where full attention layers are explicitly specified + rather than following a regular pattern. + + Difference from add_vision_n_wa_pattern: + - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention) + - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern) + """ + self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers) + def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_window_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) + # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 115df6c7c3..64dd4ddca5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1221,6 +1221,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "merger.mlp.{bid}", ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1258,6 +1259,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "siglip2.vision_model.embeddings.patch_embedding", ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1291,6 +1293,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1308,6 +1311,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1325,6 +1329,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1339,6 +1344,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1354,6 +1360,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1368,6 +1375,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1383,6 +1391,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1404,6 +1413,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1430,6 +1440,7 @@ class TensorNameMap: "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl "visual.post_layernorm", # glm4v + "siglip2.vision_model.post_layernorm", ), MODEL_TENSOR.V_MM_POST_NORM: ( @@ -1446,6 +1457,7 @@ class TensorNameMap: "multi_modal_projector.pre_norm", "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm + "merger.ln_q", ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d8b1221df5..c2cd44de44 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1683,7 +1683,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { @@ -4785,7 +4785,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4848,7 +4852,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index c57055082b..bd311bea45 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_YOUTU: + regex_exprs = { + "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+", + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -1861,6 +1867,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "youtu") { + pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index f5bdd22311..2b240a5491 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -52,6 +52,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, }; struct LLM_KV; diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 49382874ba..ca63a62ad1 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/src/unicode.cpp b/src/unicode.cpp index bb44edfadd..b47dcbe619 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -964,6 +964,11 @@ std::vector unicode_regex_split(const std::string & text, const std { "\\p{P}", unicode_cpt_flags::PUNCTUATION }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, { "\\p{S}", unicode_cpt_flags::SYMBOL }, + { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter + { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter + { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter + { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter + { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter }; static const std::map k_ucat_cpt = { @@ -1074,22 +1079,26 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + // Match \p{...} Unicode properties of varying lengths + if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() && regex_expr[i + 1] == 'p' && - regex_expr[i + 2] == '{' && - regex_expr[i + 4] == '}') { - const std::string pat = regex_expr.substr(i, 5); - if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { - if (!inside) { - regex_expr_collapsed += '['; + regex_expr[i + 2] == '{') { + // Find the closing brace + size_t closing_brace = regex_expr.find('}', i + 3); + if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit + const std::string pat = regex_expr.substr(i, closing_brace - i + 1); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i = closing_brace; + continue; } - regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); - regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); - if (!inside) { - regex_expr_collapsed += ']'; - } - i += 4; - continue; } } diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd..4b9022cb58 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/youtuvl.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1ed0741883..df7e479765 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -45,13 +45,14 @@ #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" -#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" -#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" -#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" -#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" -#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" +#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" +#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" +#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" +#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" // audio-specific #define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities @@ -188,6 +189,7 @@ enum projector_type { PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, + PROJECTOR_TYPE_YOUTUVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -218,6 +220,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, + { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 1e5aa87b98..702e10151a 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -61,6 +61,7 @@ struct clip_hparams { std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) // audio int32_t n_mel_bins = 0; // whisper preprocessor diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index fb08dd258c..9f551e8f3c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -846,6 +846,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1159,6 +1163,20 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_YOUTUVL: + { + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true); + std::vector wa_layer_indexes_vec; + get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true); + for (auto & layer : wa_layer_indexes_vec) { + hparams.wa_layer_indexes.insert(layer); + } + // support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens + hparams.set_limit_image_tokens(1, 62500); + hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1227,7 +1245,14 @@ struct clip_model_loader { LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge); - LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + if (!hparams.wa_layer_indexes.empty()) { + LOG_INF("%s: wa_layer_indexes: ", __func__); + for (auto & layer : hparams.wa_layer_indexes) { + LOG_INF("%d ", layer); + } + LOG_INF("\n"); + } if (hparams.image_min_pixels > 0) { LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : ""); } @@ -1495,6 +1520,14 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm) + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0 + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_GLM4V: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -2697,6 +2730,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + const int patch_size = params.patch_size; // typically 16 + const int merge_size = params.n_merge; // typically 2 + const int align_size = patch_size * merge_size; // 32 + + const int max_num_patches = params.image_max_pixels > 0 ? + params.image_max_pixels / (patch_size * patch_size) : 256; + + // Linear search for optimal scale to fit within max_num_patches + float scale = 1.0f; + int target_height = original_size.height; + int target_width = original_size.width; + + auto get_scaled_image_size = [align_size](float scale, int size) -> int { + float scaled_size = size * scale; + // Round up to nearest multiple of align_size + int aligned = static_cast(std::ceil(scaled_size / align_size)) * align_size; + // Ensure at least one patch + return std::max(align_size, aligned); + }; + + // Linear search with 0.02 step size + while (scale > 0.0f) { + target_height = get_scaled_image_size(scale, original_size.height); + target_width = get_scaled_image_size(scale, original_size.width); + + int num_patches_h = target_height / patch_size; + int num_patches_w = target_width / patch_size; + int num_patches = num_patches_h * num_patches_w; + + if (num_patches > max_num_patches) { + scale -= 0.02f; + } else { + break; + } + } + + clip_image_size new_size = {target_width, target_height}; + + // Resize the image + clip_image_u8 resized; + img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false); + + // Normalize to float32 + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); + + // Add to results + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_IDEFICS3: { @@ -2929,6 +3013,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; default: break; @@ -2944,6 +3029,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; default: break; @@ -3004,6 +3090,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: { // dynamic size (2 conv, so double patch size) int x_patch = img->nx / (params.patch_size * 2); @@ -3131,7 +3218,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int pos_w = image_size_width / patch_size; const int pos_h = image_size_height / patch_size; - const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl auto get_inp_tensor = [&gf](const char * name) { ggml_tensor * inp = ggml_graph_get_tensor(gf, name); @@ -3280,9 +3366,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_YOUTUVL: { // pw * ph = number of tokens output by ViT after apply patch merger // ipw * ipw = number of vision token been processed inside ViT + const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty(); const int merge_ratio = 2; const int pw = image_size_width / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio; @@ -3293,7 +3381,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima std::vector inv_idx(ph * pw); if (use_window_attn) { - const int attn_window_size = 112; + const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112; const int grid_window = attn_window_size / patch_size / merge_ratio; int dst = 0; // [num_vision_tokens, num_vision_tokens] attention mask tensor @@ -3531,6 +3619,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_JANUS_PRO: + case PROJECTOR_TYPE_YOUTUVL: return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index e08c33f353..74e94f60ec 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -27,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_youtuvl : clip_graph { + clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/youtuvl.cpp b/tools/mtmd/models/youtuvl.cpp new file mode 100644 index 0000000000..ffbf2be554 --- /dev/null +++ b/tools/mtmd/models/youtuvl.cpp @@ -0,0 +1,179 @@ +#include "models.h" + +ggml_cgraph * clip_graph_youtuvl::build() { + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; + const bool use_window_attn = !hparams.wa_layer_indexes.empty(); + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; + const int m = 2; + const int Wp = n_patches_x; + const int Hp = n_patches_y; + const int Hm = Hp / m; + const int Wm = Wp / m; + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp = build_inp_raw(); + + // change conv3d to linear + // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm) + { + inp = ggml_reshape_4d( + ctx0, inp, + Wm * m * patch_size, m * patch_size, Hm, 3); + inp = ggml_permute(ctx0, inp, 1, 2, 3, 0); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, Wm, m * patch_size, Hm); + + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, patch_size, m, Hm * Wm); + + inp = ggml_permute(ctx0, inp, 1, 0, 2, 3); + inp = ggml_cont_4d( + ctx0, inp, + patch_size, 3, patch_size, Hm * Wm * m * m); + + inp = ggml_permute(ctx0, inp, 2, 0, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + 3*patch_size* patch_size, Hm * Wm * m * m, 1); + } + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + + ggml_tensor * inpL = inp; + ggml_tensor * window_mask = nullptr; + ggml_tensor * window_idx = nullptr; + ggml_tensor * inv_window_idx = nullptr; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + if (use_window_attn) { + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // if flash attn is used, we need to pad the mask and cast to f16 + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); + } + + // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); + inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + ggml_tensor * attn_mask = full_attn ? nullptr : window_mask; + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + } + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + nullptr, nullptr, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + inpL = cur; + } + + ggml_tensor * embeddings = inpL; + if (use_window_attn) { + const int spatial_merge_unit = 4; + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size); + cb(embeddings, "window_order_restored", -1); + } + + // post-layernorm (part of Siglip2VisionTransformer, applied after encoder) + if (model.post_ln_w) { + embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // Now apply merger (VLPatchMerger): + // 1. Apply RMS norm (ln_q in VLPatchMerger) + embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cb(embeddings, "merger_normed", -1); + + // 2. First reshape for spatial merge (merge 2x2 patches) + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + cb(embeddings, "merger_reshaped", -1); + + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b0b5ab42ab..fca55b76f8 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -283,7 +283,7 @@ struct mtmd_context { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md img_end = "[IMG_END]"; - } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) { + } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; img_end = "<|vision_end|>";