diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index edc0ed539d..395d0d37ba 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3937,9 +3937,7 @@ class Qwen2VLVisionModel(MmprojModel):
         return []  # skip other tensors


-@ModelBase.register("Qwen2_5OmniModel")
-class Qwen25OmniModel(Qwen2VLVisionModel):
-    has_vision_encoder = True
+class Qwen25AudioModel(MmprojModel):
     has_audio_encoder = True

     def __init__(self, *args, **kwargs):
@@ -3955,12 +3953,6 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
         self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
         self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))

-    def get_vision_config(self) -> dict[str, Any] | None:
-        return self.global_config["thinker_config"].get("vision_config")
-
-    def get_audio_config(self) -> dict[str, Any] | None:
-        return self.global_config["thinker_config"].get("audio_config")
-
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         # SinusoidsPositionEmbedding
         assert self.hparams_audio is not None
@@ -3993,7 +3985,30 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
                 return []
             return [(self.map_tensor_name(name), data_torch)]

-        return super().modify_tensors(data_torch, name, bid)
+        return []  # skip other tensors
+
+
+@ModelBase.register("Qwen2_5OmniModel")
+class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel):
+    has_audio_encoder = True
+    has_vision_encoder = True
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("audio_config")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "visual." in name:
+            return Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
+        elif "audio_tower." in name:
+            return Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
+        return []


 @ModelBase.register("InternVisionModel")
@@ -4387,7 +4402,9 @@ class Qwen3VLVisionModel(MmprojModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        # in case mixed modalities, the arch will be handled by subclass
+        if not self.has_audio_encoder:
+            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
         self.gguf_writer.add_vision_use_gelu(True)

         if self.hparams_vision is not None:
@@ -4470,9 +4487,42 @@ class Qwen3VLVisionModel(MmprojModel):
         if name.startswith("visual."):
             return [(self.map_tensor_name(name), data_torch)]

+        return []  # skip other tensors

-        # Fall back to parent class for other tensors
-        return super().modify_tensors(data_torch, name, bid)
+
+@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
+class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel):
+    has_audio_encoder = True
+    has_vision_encoder = True
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("audio_config")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "visual." in name:
+            # need to transform vision tensor naming, so that modify_tensors() logic can be used correctly
+            name = name.replace("thinker.visual.", "model.visual.")
+            if ".merger_list." in name:
+                name = name.replace(".merger_list.", ".deepstack_merger_list.")
+                name = name.replace(".ln_q", ".norm")
+                name = name.replace(".mlp.0", ".linear_fc1")
+                name = name.replace(".mlp.2", ".linear_fc2")
+            if ".merger." in name:
+                name = name.replace(".ln_q", ".norm")
+                name = name.replace(".mlp.0", ".linear_fc1")
+                name = name.replace(".mlp.2", ".linear_fc2")
+            return Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid)
+        elif "audio_tower." in name:
+            return Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
+        return []


 @ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
@@ -4519,7 +4569,7 @@ class Qwen3VLTextModel(Qwen3Model):
         return super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration")
 class Qwen3VLMoeTextModel(Qwen3MoeModel):
     model_arch = gguf.MODEL_ARCH.QWEN3VLMOE

@@ -4531,9 +4581,13 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # Skip vision tensors - they go in the mmproj file
-        if name.startswith("model.visual."):
+        if "visual." in name or "audio_tower." in name \
+                or "talker." in name or "code2wav." in name:
             return []

+        # qwen3-omni
+        name = name.replace("thinker.", "")
+
         return super().modify_tensors(data_torch, name, bid)

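Note on the converter pattern above: Qwen25OmniModel and Qwen3OmniMmprojModel
inherit from one vision class plus the new Qwen25AudioModel, and their
modify_tensors() routes each checkpoint tensor to exactly one parent by name
prefix, calling that parent explicitly rather than going through super()'s
MRO. A minimal, self-contained sketch of that dispatch (hypothetical stand-in
classes, not the real converter):

class VisionPart:
    def modify_tensors(self, name: str) -> list[tuple[str, str]]:
        return [(name.replace("visual.", "v."), "vision")]

class AudioPart:
    def modify_tensors(self, name: str) -> list[tuple[str, str]]:
        return [(name.replace("audio_tower.", "a."), "audio")]

class OmniPart(VisionPart, AudioPart):
    def modify_tensors(self, name: str) -> list[tuple[str, str]]:
        if "visual." in name:
            return VisionPart.modify_tensors(self, name)   # explicit parent call
        elif "audio_tower." in name:
            return AudioPart.modify_tensors(self, name)
        return []  # e.g. talker./code2wav. tensors are dropped

m = OmniPart()
print(m.modify_tensors("thinker.visual.blk.0.attn_q.weight"))
print(m.modify_tensors("talker.lm_head.weight"))  # -> []
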
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index c2a0f41c1b..1599aa15d9 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -275,10 +275,12 @@ class Keys:
         DATASETS = "imatrix.datasets"

     class Clip:
-        PROJECTOR_TYPE       = "clip.projector_type"
-        HAS_VISION_ENCODER   = "clip.has_vision_encoder"
-        HAS_AUDIO_ENCODER    = "clip.has_audio_encoder"
-        HAS_LLAVA_PROJECTOR  = "clip.has_llava_projector"
+        PROJECTOR_TYPE        = "clip.projector_type"
+        VISION_PROJECTOR_TYPE = "clip.vision.projector_type"  # for mixed modalities
+        AUDIO_PROJECTOR_TYPE  = "clip.audio.projector_type"   # for mixed modalities
+        HAS_VISION_ENCODER    = "clip.has_vision_encoder"
+        HAS_AUDIO_ENCODER     = "clip.has_audio_encoder"
+        HAS_LLAVA_PROJECTOR   = "clip.has_llava_projector"

     class ClipVision:
         IMAGE_SIZE = "clip.vision.image_size"
@@ -698,6 +700,7 @@ class MODEL_TENSOR(IntEnum):
     A_ENC_EMBD_NORM      = auto()
     A_ENC_EMBD_TO_LOGITS = auto()
     A_ENC_CONV1D         = auto()
+    A_ENC_CONV_OUT       = auto()
     A_PRE_NORM           = auto()
     A_POST_NORM          = auto()
     A_ENC_ATTN_Q         = auto()
@@ -1095,6 +1098,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.A_ENC_EMBD_NORM:      "a.position_embd_norm",
    MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
    MODEL_TENSOR.A_ENC_CONV1D:         "a.conv1d.{bid}",
+   MODEL_TENSOR.A_ENC_CONV_OUT:       "a.conv_out",
    MODEL_TENSOR.A_PRE_NORM:           "a.pre_ln",
    MODEL_TENSOR.A_POST_NORM:          "a.post_ln",
    MODEL_TENSOR.A_ENC_ATTN_Q:         "a.blk.{bid}.attn_q",
@@ -1192,6 +1196,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.A_ENC_EMBD_NORM,
         MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
         MODEL_TENSOR.A_ENC_CONV1D,
+        MODEL_TENSOR.A_ENC_CONV_OUT,
         MODEL_TENSOR.A_PRE_NORM,
         MODEL_TENSOR.A_POST_NORM,
         MODEL_TENSOR.A_ENC_ATTN_Q,
@@ -3483,6 +3488,7 @@ class VisionProjectorType:
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
     QWEN2A = "qwen2a" # audio
+    QWEN3A = "qwen3a" # audio
     GLMA = "glma" # audio
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"
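The two new Keys.Clip entries let a single mmproj file carry one projector
type per modality. A sketch of how a consumer might resolve them, with a
plain dict standing in for the GGUF KV store (resolve_projector() is
illustrative, not a gguf-py API):

def resolve_projector(kv: dict[str, str], modality: str) -> str | None:
    # per-modality key first (mixed-modality files), then the shared key
    return kv.get(f"clip.{modality}.projector_type") or kv.get("clip.projector_type")

kv = {  # keys written by Qwen3OmniMmprojModel above
    "clip.vision.projector_type": "qwen3vl",
    "clip.audio.projector_type": "qwen3a",
}
assert resolve_projector(kv, "audio") == "qwen3a"
# single-modality files keep working through the shared key:
assert resolve_projector({"clip.projector_type": "qwen2.5o"}, "audio") == "qwen2.5o"
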
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 6a4a504f8d..996eba85ee 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -1083,6 +1083,12 @@ class GGUFWriter:
     def add_clip_projector_type(self, value: str) -> None:
         self.add_string(Keys.Clip.PROJECTOR_TYPE, value)

+    def add_clip_vision_projector_type(self, value: str) -> None:
+        self.add_string(Keys.Clip.VISION_PROJECTOR_TYPE, value)
+
+    def add_clip_audio_projector_type(self, value: str) -> None:
+        self.add_string(Keys.Clip.AUDIO_PROJECTOR_TYPE, value)
+
     def add_vision_projection_dim(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)

diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 115df6c7c3..aeb6843c3e 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1563,6 +1563,11 @@ class TensorNameMap:
         MODEL_TENSOR.A_ENC_CONV1D: (
             "audio_tower.conv{bid}",            # ultravox
             "conformer.pre_encode.conv.{bid}",  # lfm2
+            "audio_tower.conv2d{bid}",          # qwen3omni
+        ),
+
+        MODEL_TENSOR.A_ENC_CONV_OUT: (
+            "audio_tower.conv_out",  # qwen3omni
         ),

         MODEL_TENSOR.A_PRE_NORM: (),
@@ -1651,7 +1656,8 @@ class TensorNameMap:

         MODEL_TENSOR.A_MMPROJ: (
             "audio.multi_modal_projector.linear_{bid}",  # ultravox
-            "audio_adapter.model.{bid}"                  # lfm2
+            "audio_adapter.model.{bid}",                 # lfm2
+            "audio_tower.proj{bid}",                     # qwen3omni
         ),

         MODEL_TENSOR.A_MMPROJ_FC: (
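The mapping additions reuse the existing A_ENC_CONV1D slot for qwen3omni's
audio_tower.conv2d{bid} layers and introduce A_ENC_CONV_OUT plus an extra
A_MMPROJ source. A toy version of the resulting HF-to-GGUF renames
(simplified: the real TensorNameMap pre-expands {bid} for every block and
matches .weight/.bias suffixes separately):

HF_TO_GGUF = {
    "audio_tower.conv2d{bid}": "a.conv1d.{bid}",  # reuses the conv1d slot
    "audio_tower.conv_out":    "a.conv_out",      # new A_ENC_CONV_OUT
    "audio_tower.proj{bid}":   "mm.a.mlp.{bid}",  # A_MMPROJ
}

def map_audio_tensor(hf_name: str, bid: int) -> str | None:
    for src, dst in HF_TO_GGUF.items():
        if hf_name == src.format(bid=bid):
            return dst.format(bid=bid)
    return None

assert map_audio_tensor("audio_tower.conv2d1", 1) == "a.conv1d.1"
assert map_audio_tensor("audio_tower.conv_out", 0) == "a.conv_out"
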
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 1ed0741883..f77ec73f4a 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -125,6 +125,7 @@

 // ultravox
 #define TN_CONV1D       "a.conv1d.%d.%s"
+#define TN_CONV_OUT     "a.conv_out.%s"
 #define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
 #define TN_MM_AUDIO_FC  "mm.a.fc.%s" // fully connected layer
 #define TN_MM_NORM_PRE  "mm.a.norm_pre.%s"
@@ -177,6 +178,7 @@ enum projector_type {
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_QWEN2A,
+    PROJECTOR_TYPE_QWEN3A,
     PROJECTOR_TYPE_GLMA,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
@@ -207,6 +209,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_INTERNVL,  "internvl"},
     { PROJECTOR_TYPE_LLAMA4,    "llama4"},
     { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
+    { PROJECTOR_TYPE_QWEN3A,    "qwen3a"},
     { PROJECTOR_TYPE_GLMA,      "glma"},
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h
index 1e5aa87b98..2c4c1547a2 100644
--- a/tools/mtmd/clip-model.h
+++ b/tools/mtmd/clip-model.h
@@ -298,6 +298,8 @@ struct clip_model {
     ggml_tensor * conv1d_1_b = nullptr;
     ggml_tensor * conv1d_2_w = nullptr;
     ggml_tensor * conv1d_2_b = nullptr;
+    ggml_tensor * conv_out_w = nullptr;
+    ggml_tensor * conv_out_b = nullptr;
     ggml_tensor * mm_norm_pre_w = nullptr;
     ggml_tensor * mm_norm_pre_b = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;
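Expanding TN_CONV1D ("a.conv1d.%d.%s"), TN_CONV_OUT ("a.conv_out.%s") and
TN_MM_AUDIO_MLP ("mm.a.mlp.%d.%s"), the QWEN3A loader case below requests
exactly the following GGUF tensor names; conv_out deliberately has no bias.
A small checklist helper for validating a converted file (illustrative, not
shipped code):

EXPECTED_QWEN3A = [
    "a.conv1d.1.weight", "a.conv1d.1.bias",
    "a.conv1d.2.weight", "a.conv1d.2.bias",
    "a.conv_out.weight",  # no bias tensor
    "mm.a.mlp.1.weight", "mm.a.mlp.1.bias",
    "mm.a.mlp.2.weight", "mm.a.mlp.2.bias",
]

def missing_qwen3a_tensors(present: set[str]) -> list[str]:
    return [t for t in EXPECTED_QWEN3A if t not in present]

print(missing_qwen3a_tensors({"a.conv1d.1.weight"}))  # everything else is missing
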
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index fb08dd258c..af926c6da0 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -817,6 +817,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_QWEN3A:
         case PROJECTOR_TYPE_GLMA:
         case PROJECTOR_TYPE_MUSIC_FLAMINGO:
             {
@@ -1175,6 +1176,7 @@ struct clip_model_loader {
                 } break;
             case PROJECTOR_TYPE_ULTRAVOX:
             case PROJECTOR_TYPE_QWEN2A:
+            case PROJECTOR_TYPE_QWEN3A:
             case PROJECTOR_TYPE_GLMA:
             case PROJECTOR_TYPE_VOXTRAL:
             case PROJECTOR_TYPE_MUSIC_FLAMINGO:
@@ -1569,6 +1571,18 @@ struct clip_model_loader {
                     model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
                     model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
                 } break;
+            case PROJECTOR_TYPE_QWEN3A:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_VOXTRAL:
                 {
                     model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
@@ -3044,6 +3058,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_QWEN3A:
         case PROJECTOR_TYPE_MUSIC_FLAMINGO:
             {
                 n_patches = img->nx;
@@ -3413,6 +3428,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_QWEN3A:
         case PROJECTOR_TYPE_GLMA:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_LFM2:
@@ -3549,8 +3565,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_proj->ne[1];
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
+        case PROJECTOR_TYPE_QWEN3A:
         case PROJECTOR_TYPE_GLMA:
-            return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
@@ -3602,6 +3618,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
 bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
         || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3A
         || ctx->proj_type() == PROJECTOR_TYPE_GLMA
         || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL
         || ctx->proj_type() == PROJECTOR_TYPE_MUSIC_FLAMINGO;
diff --git a/tools/mtmd/models/whisper-enc.cpp b/tools/mtmd/models/whisper-enc.cpp
index 2f2b127755..a541857797 100644
--- a/tools/mtmd/models/whisper-enc.cpp
+++ b/tools/mtmd/models/whisper-enc.cpp
@@ -19,9 +19,18 @@ ggml_cgraph * clip_graph_whisper_enc::build() {

         cur = ggml_add(ctx0, cur, model.conv1d_2_b);
         cur = ggml_gelu_erf(ctx0, cur);

+        // transpose
         inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
         cb(inp, "after_conv1d", -1);
+
+        if (model.conv_out_w) {
+            inp = ggml_mul_mat(ctx0, model.conv_out_w, inp);
+            if (model.conv_out_b) {
+                inp = ggml_add(ctx0, inp, model.conv_out_b);
+            }
+            cb(inp, "after_conv_out", -1);
+        }
     }

     // sanity check (only check one layer, but it should be the same for all)
@@ -77,6 +86,15 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
         cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
         cur = ggml_add(ctx0, cur, model.mm_fc_b);

+    } else if (proj_type == PROJECTOR_TYPE_QWEN3A) {
+        // projector
+        cur = build_ffn(cur,
+            model.mm_1_w, model.mm_1_b,
+            nullptr, nullptr,
+            model.mm_2_w, model.mm_2_b,
+            FFN_GELU_ERF,
+            -1);
+
     } else if (proj_type == PROJECTOR_TYPE_VOXTRAL) {
         // projector
         cur = build_ffn(cur,
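For orientation, a shape-level numpy sketch of the audio path this change
wires up in whisper-enc.cpp: the existing conv front end's output is
transposed, passed through the new bias-free conv_out projection, and (after
the whisper encoder blocks, elided here) through the new QWEN3A two-layer
GELU(erf) MLP projector. All dimensions below are toy values, not the real
model's:

import math
import numpy as np

def gelu_erf(x: np.ndarray) -> np.ndarray:
    # erf-based GELU, matching ggml_gelu_erf
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

rng = np.random.default_rng(0)
T, d_conv, d_enc, d_model = 50, 480, 1024, 2048  # toy sizes

cur = rng.standard_normal((d_conv, T))  # stands in for the conv1d stack output
inp = cur.T                             # transpose: one row per audio frame
conv_out_w = rng.standard_normal((d_enc, d_conv))
inp = inp @ conv_out_w.T                # conv_out projection, no bias

# ... whisper encoder blocks would run on `inp` here ...

mm_1_w = rng.standard_normal((4 * d_enc, d_enc))
mm_1_b = rng.standard_normal(4 * d_enc)
mm_2_w = rng.standard_normal((d_model, 4 * d_enc))
mm_2_b = rng.standard_normal(d_model)
out = gelu_erf(inp @ mm_1_w.T + mm_1_b) @ mm_2_w.T + mm_2_b
print(out.shape)  # (T, d_model): one embedding per audio token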