From d7435467e79a14fd9875da7154f7360b75a21e33 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 1 Jan 2026 15:40:36 +0100 Subject: [PATCH 1/6] add qwen3a --- convert_hf_to_gguf.py | 84 +++++++++++++++++++++++++------ gguf-py/gguf/constants.py | 14 ++++-- gguf-py/gguf/gguf_writer.py | 6 +++ gguf-py/gguf/tensor_mapping.py | 8 ++- tools/mtmd/clip-impl.h | 3 ++ tools/mtmd/clip-model.h | 2 + tools/mtmd/clip.cpp | 19 ++++++- tools/mtmd/models/whisper-enc.cpp | 18 +++++++ 8 files changed, 133 insertions(+), 21 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index edc0ed539d..395d0d37ba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3937,9 +3937,7 @@ class Qwen2VLVisionModel(MmprojModel): return [] # skip other tensors -@ModelBase.register("Qwen2_5OmniModel") -class Qwen25OmniModel(Qwen2VLVisionModel): - has_vision_encoder = True +class Qwen25AudioModel(MmprojModel): has_audio_encoder = True def __init__(self, *args, **kwargs): @@ -3955,12 +3953,6 @@ class Qwen25OmniModel(Qwen2VLVisionModel): self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5)) - def get_vision_config(self) -> dict[str, Any] | None: - return self.global_config["thinker_config"].get("vision_config") - - def get_audio_config(self) -> dict[str, Any] | None: - return self.global_config["thinker_config"].get("audio_config") - def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: # SinusoidsPositionEmbedding assert self.hparams_audio is not None @@ -3993,7 +3985,30 @@ class Qwen25OmniModel(Qwen2VLVisionModel): return [] return [(self.map_tensor_name(name), data_torch)] - return super().modify_tensors(data_torch, name, bid) + return [] # skip other tensors + + +@ModelBase.register("Qwen2_5OmniModel") +class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel): + has_audio_encoder = True + has_vision_encoder = True + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("audio_config") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "visual." in name: + return Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid) + elif "audio_tower." 
in name: + return Qwen25AudioModel.modify_tensors(self, data_torch, name, bid) + return [] @ModelBase.register("InternVisionModel") @@ -4387,7 +4402,9 @@ class Qwen3VLVisionModel(MmprojModel): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL) + # in case mixed modalities, the arch will be handled by subclass + if not self.has_audio_encoder: + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL) self.gguf_writer.add_vision_use_gelu(True) if self.hparams_vision is not None: @@ -4470,9 +4487,42 @@ class Qwen3VLVisionModel(MmprojModel): if name.startswith("visual."): return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors - # Fall back to parent class for other tensors - return super().modify_tensors(data_torch, name, bid) + +@ModelBase.register("Qwen3OmniMoeForConditionalGeneration") +class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel): + has_audio_encoder = True + has_vision_encoder = True + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("audio_config") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL) + self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "visual." in name: + # need to transform vision tensor naming, so that modify_tensors() logic can be used correctly + name = name.replace("thinker.visual.", "model.visual.") + if ".merger_list." in name: + name = name.replace(".merger_list.", ".deepstack_merger_list.") + name = name.replace(".ln_q", ".norm") + name = name.replace(".mlp.0", ".linear_fc1") + name = name.replace(".mlp.2", ".linear_fc2") + if ".merger." in name: + name = name.replace(".ln_q", ".norm") + name = name.replace(".mlp.0", ".linear_fc1") + name = name.replace(".mlp.2", ".linear_fc2") + return Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid) + elif "audio_tower." in name: + return Qwen25AudioModel.modify_tensors(self, data_torch, name, bid) + return [] @ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration") @@ -4519,7 +4569,7 @@ class Qwen3VLTextModel(Qwen3Model): return super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Qwen3VLMoeForConditionalGeneration") +@ModelBase.register("Qwen3VLMoeForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration") class Qwen3VLMoeTextModel(Qwen3MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3VLMOE @@ -4531,9 +4581,13 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Skip vision tensors - they go in the mmproj file - if name.startswith("model.visual."): + if "visual." in name or "audio_tower." in name \ + or "talker." in name or "code2wav." 
in name: return [] + # qwen3-omni + name = name.replace("thinker.", "") + return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c2a0f41c1b..1599aa15d9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -275,10 +275,12 @@ class Keys: DATASETS = "imatrix.datasets" class Clip: - PROJECTOR_TYPE = "clip.projector_type" - HAS_VISION_ENCODER = "clip.has_vision_encoder" - HAS_AUDIO_ENCODER = "clip.has_audio_encoder" - HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" + PROJECTOR_TYPE = "clip.projector_type" + VISION_PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modalities + AUDIO_PROJECTOR_TYPE = "clip.audio.projector_type" # for mixed modalities + HAS_VISION_ENCODER = "clip.has_vision_encoder" + HAS_AUDIO_ENCODER = "clip.has_audio_encoder" + HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" class ClipVision: IMAGE_SIZE = "clip.vision.image_size" @@ -698,6 +700,7 @@ class MODEL_TENSOR(IntEnum): A_ENC_EMBD_NORM = auto() A_ENC_EMBD_TO_LOGITS = auto() A_ENC_CONV1D = auto() + A_ENC_CONV_OUT = auto() A_PRE_NORM = auto() A_POST_NORM = auto() A_ENC_ATTN_Q = auto() @@ -1095,6 +1098,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm", MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", + MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out", MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", MODEL_TENSOR.A_POST_NORM: "a.post_ln", MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q", @@ -1192,6 +1196,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.A_ENC_EMBD_NORM, MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS, MODEL_TENSOR.A_ENC_CONV1D, + MODEL_TENSOR.A_ENC_CONV_OUT, MODEL_TENSOR.A_PRE_NORM, MODEL_TENSOR.A_POST_NORM, MODEL_TENSOR.A_ENC_ATTN_Q, @@ -3483,6 +3488,7 @@ class VisionProjectorType: ULTRAVOX = "ultravox" INTERNVL = "internvl" QWEN2A = "qwen2a" # audio + QWEN3A = "qwen3a" # audio GLMA = "glma" # audio QWEN25O = "qwen2.5o" # omni VOXTRAL = "voxtral" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a4a504f8d..996eba85ee 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1083,6 +1083,12 @@ class GGUFWriter: def add_clip_projector_type(self, value: str) -> None: self.add_string(Keys.Clip.PROJECTOR_TYPE, value) + def add_clip_vision_projector_type(self, value: str) -> None: + self.add_string(Keys.Clip.VISION_PROJECTOR_TYPE, value) + + def add_clip_audio_projector_type(self, value: str) -> None: + self.add_string(Keys.Clip.AUDIO_PROJECTOR_TYPE, value) + def add_vision_projection_dim(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 115df6c7c3..aeb6843c3e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1563,6 +1563,11 @@ class TensorNameMap: MODEL_TENSOR.A_ENC_CONV1D: ( "audio_tower.conv{bid}", # ultravox "conformer.pre_encode.conv.{bid}", # lfm2 + "audio_tower.conv2d{bid}", # qwen3omni + ), + + MODEL_TENSOR.A_ENC_CONV_OUT: ( + "audio_tower.conv_out", # qwen3omni ), MODEL_TENSOR.A_PRE_NORM: (), @@ -1651,7 +1656,8 @@ class TensorNameMap: MODEL_TENSOR.A_MMPROJ: ( "audio.multi_modal_projector.linear_{bid}", # ultravox - "audio_adapter.model.{bid}" # lfm2 + "audio_adapter.model.{bid}", # lfm2 + "audio_tower.proj{bid}", # qwen3omni ), MODEL_TENSOR.A_MMPROJ_FC: ( diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h 
index 1ed0741883..f77ec73f4a 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -125,6 +125,7 @@ // ultravox #define TN_CONV1D "a.conv1d.%d.%s" +#define TN_CONV_OUT "a.conv_out.%s" #define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s" #define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer #define TN_MM_NORM_PRE "mm.a.norm_pre.%s" @@ -177,6 +178,7 @@ enum projector_type { PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_QWEN2A, + PROJECTOR_TYPE_QWEN3A, PROJECTOR_TYPE_GLMA, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, @@ -207,6 +209,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_INTERNVL, "internvl"}, { PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, + { PROJECTOR_TYPE_QWEN3A, "qwen3a"}, { PROJECTOR_TYPE_GLMA, "glma"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 1e5aa87b98..2c4c1547a2 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -298,6 +298,8 @@ struct clip_model { ggml_tensor * conv1d_1_b = nullptr; ggml_tensor * conv1d_2_w = nullptr; ggml_tensor * conv1d_2_b = nullptr; + ggml_tensor * conv_out_w = nullptr; + ggml_tensor * conv_out_b = nullptr; ggml_tensor * mm_norm_pre_w = nullptr; ggml_tensor * mm_norm_pre_b = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index fb08dd258c..af926c6da0 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -817,6 +817,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_MUSIC_FLAMINGO: { @@ -1175,6 +1176,7 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_MUSIC_FLAMINGO: @@ -1569,6 +1571,18 @@ struct clip_model_loader { model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight")); model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias")); } break; + case PROJECTOR_TYPE_QWEN3A: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias")); + } break; case PROJECTOR_TYPE_VOXTRAL: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -3044,6 +3058,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_MUSIC_FLAMINGO: { n_patches = img->nx; @@ -3413,6 +3428,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case 
PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_LFM2: @@ -3549,8 +3565,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; + case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_GLMA: - return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: return ctx->model.mm_2_w->ne[1]; @@ -3602,6 +3618,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) { bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A + || ctx->proj_type() == PROJECTOR_TYPE_QWEN3A || ctx->proj_type() == PROJECTOR_TYPE_GLMA || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL || ctx->proj_type() == PROJECTOR_TYPE_MUSIC_FLAMINGO; diff --git a/tools/mtmd/models/whisper-enc.cpp b/tools/mtmd/models/whisper-enc.cpp index 2f2b127755..a541857797 100644 --- a/tools/mtmd/models/whisper-enc.cpp +++ b/tools/mtmd/models/whisper-enc.cpp @@ -19,9 +19,18 @@ ggml_cgraph * clip_graph_whisper_enc::build() { cur = ggml_add(ctx0, cur, model.conv1d_2_b); cur = ggml_gelu_erf(ctx0, cur); + // transpose inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); cb(inp, "after_conv1d", -1); + + if (model.conv_out_w) { + inp = ggml_mul_mat(ctx0, model.conv_out_w, inp); + if (model.conv_out_b) { + inp = ggml_add(ctx0, inp, model.conv_out_b); + } + cb(inp, "after_conv_out", -1); + } } // sanity check (only check one layer, but it should be the same for all) @@ -77,6 +86,15 @@ ggml_cgraph * clip_graph_whisper_enc::build() { cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); cur = ggml_add(ctx0, cur, model.mm_fc_b); + } else if (proj_type == PROJECTOR_TYPE_QWEN3A) { + // projector + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + } else if (proj_type == PROJECTOR_TYPE_VOXTRAL) { // projector cur = build_ffn(cur, From d703cf718451b4bdb2760d4549eab618914c09b6 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 1 Jan 2026 21:29:41 +0100 Subject: [PATCH 2/6] wip --- convert_hf_to_gguf.py | 17 +++++--- gguf-py/gguf/constants.py | 3 ++ gguf-py/gguf/tensor_mapping.py | 3 ++ tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-impl.h | 1 + tools/mtmd/clip-model.h | 8 ++++ tools/mtmd/clip.cpp | 21 +++++++--- tools/mtmd/models/models.h | 5 +++ tools/mtmd/models/qwen3a.cpp | 69 +++++++++++++++++++++++++++++++ tools/mtmd/models/whisper-enc.cpp | 18 -------- tools/mtmd/mtmd.cpp | 3 +- 11 files changed, 119 insertions(+), 30 deletions(-) create mode 100644 tools/mtmd/models/qwen3a.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 395d0d37ba..78bc36bbee 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4515,12 +4515,15 @@ class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel): name = name.replace(".ln_q", ".norm") name = name.replace(".mlp.0", ".linear_fc1") name = name.replace(".mlp.2", ".linear_fc2") - if ".merger." in name: + elif ".merger." in name: name = name.replace(".ln_q", ".norm") name = name.replace(".mlp.0", ".linear_fc1") name = name.replace(".mlp.2", ".linear_fc2") return Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid) elif "audio_tower." 
in name: + if "conv2d" in name and name.endswith(".bias"): + # transform conv2d bias [n_embd] --> [1, 1, n_embd] + data_torch = data_torch.unsqueeze(-1).unsqueeze(-1) return Qwen25AudioModel.modify_tensors(self, data_torch, name, bid) return [] @@ -4555,9 +4558,10 @@ class Qwen3VLTextModel(Qwen3Model): def set_gguf_parameters(self): super().set_gguf_parameters() - - # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL - vision_config = self.hparams.get("vision_config", {}) + if "thinker_config" in self.hparams: + vision_config = self.hparams["thinker_config"].get("vision_config", {}) + else: + vision_config = self.hparams.get("vision_config", {}) deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", [])) self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num) @@ -4575,7 +4579,10 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel): def set_gguf_parameters(self): super().set_gguf_parameters() - vision_config = self.hparams.get("vision_config", {}) + if "thinker_config" in self.hparams: + vision_config = self.hparams["thinker_config"].get("vision_config", {}) + else: + vision_config = self.hparams.get("vision_config", {}) deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", [])) self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1599aa15d9..190d4e353e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -700,6 +700,7 @@ class MODEL_TENSOR(IntEnum): A_ENC_EMBD_NORM = auto() A_ENC_EMBD_TO_LOGITS = auto() A_ENC_CONV1D = auto() + A_ENC_CONV2D = auto() A_ENC_CONV_OUT = auto() A_PRE_NORM = auto() A_POST_NORM = auto() @@ -1098,6 +1099,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm", MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", + MODEL_TENSOR.A_ENC_CONV2D: "a.conv2d.{bid}", MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out", MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", MODEL_TENSOR.A_POST_NORM: "a.post_ln", @@ -1196,6 +1198,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.A_ENC_EMBD_NORM, MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS, MODEL_TENSOR.A_ENC_CONV1D, + MODEL_TENSOR.A_ENC_CONV2D, MODEL_TENSOR.A_ENC_CONV_OUT, MODEL_TENSOR.A_PRE_NORM, MODEL_TENSOR.A_POST_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index aeb6843c3e..c4d0e45d54 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1563,6 +1563,9 @@ class TensorNameMap: MODEL_TENSOR.A_ENC_CONV1D: ( "audio_tower.conv{bid}", # ultravox "conformer.pre_encode.conv.{bid}", # lfm2 + ), + + MODEL_TENSOR.A_ENC_CONV2D: ( "audio_tower.conv2d{bid}", # qwen3omni ), diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd..a8773e1124 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/qwen3a.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index f77ec73f4a..a6fad4cbe9 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -125,6 +125,7 @@ // ultravox #define TN_CONV1D "a.conv1d.%d.%s" +#define TN_CONV2D "a.conv2d.%d.%s" #define TN_CONV_OUT "a.conv_out.%s" #define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s" #define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h 
index 2c4c1547a2..41f364e5d0 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -304,6 +304,14 @@ struct clip_model { ggml_tensor * mm_norm_pre_b = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; + // qwen3a + ggml_tensor * conv2d_1_w = nullptr; + ggml_tensor * conv2d_1_b = nullptr; + ggml_tensor * conv2d_2_w = nullptr; + ggml_tensor * conv2d_2_b = nullptr; + ggml_tensor * conv2d_3_w = nullptr; + ggml_tensor * conv2d_3_b = nullptr; + // cogvlm ggml_tensor * mm_post_fc_norm_w = nullptr; ggml_tensor * mm_post_fc_norm_b = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index af926c6da0..da7f320ffb 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -817,7 +817,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: - case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_MUSIC_FLAMINGO: { @@ -847,6 +846,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_QWEN3A: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1573,10 +1576,12 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_QWEN3A: { - model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); - model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); - model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); - model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.conv2d_1_w = get_tensor(string_format(TN_CONV2D, 1, "weight")); + model.conv2d_1_b = get_tensor(string_format(TN_CONV2D, 1, "bias")); + model.conv2d_2_w = get_tensor(string_format(TN_CONV2D, 2, "weight")); + model.conv2d_2_b = get_tensor(string_format(TN_CONV2D, 2, "bias")); + model.conv2d_3_w = get_tensor(string_format(TN_CONV2D, 3, "weight")); + model.conv2d_3_b = get_tensor(string_format(TN_CONV2D, 3, "bias")); model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); @@ -3058,7 +3063,6 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: - case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_MUSIC_FLAMINGO: { n_patches = img->nx; @@ -3078,6 +3082,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches /= 2; } } break; + case PROJECTOR_TYPE_QWEN3A: + { + return 375; // TODO: calculate this + } break; case PROJECTOR_TYPE_GLMA: { n_patches = img->nx; @@ -3566,6 +3574,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; case PROJECTOR_TYPE_QWEN3A: + return ctx->model.mm_2_w->ne[1] * 4; // 4 for deepstack, TODO: do NOT hardcode case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index e08c33f353..266f01f0f9 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -71,3 +71,8 @@ struct clip_graph_glm4v : clip_graph { clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; }; + +struct clip_graph_qwen3a : clip_graph { + 
clip_graph_qwen3a(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; diff --git a/tools/mtmd/models/qwen3a.cpp b/tools/mtmd/models/qwen3a.cpp new file mode 100644 index 0000000000..2680073290 --- /dev/null +++ b/tools/mtmd/models/qwen3a.cpp @@ -0,0 +1,69 @@ +#include "models.h" + +ggml_cgraph * clip_graph_qwen3a::build() { + ggml_tensor * inp = build_inp_raw(1); + + // conv2d block + // TODO: do we need to split by chunks of n_window each like on transformers impl? + { + inp = ggml_conv_2d(ctx0, model.conv2d_1_w, inp, 2, 2, 1, 1, 1, 1); + inp = ggml_add(ctx0, inp, model.conv2d_1_b); + inp = ggml_gelu_erf(ctx0, inp); + + inp = ggml_conv_2d(ctx0, model.conv2d_2_w, inp, 2, 2, 1, 1, 1, 1); + inp = ggml_add(ctx0, inp, model.conv2d_2_b); + inp = ggml_gelu_erf(ctx0, inp); + + inp = ggml_conv_2d(ctx0, model.conv2d_3_w, inp, 2, 2, 1, 1, 1, 1); + inp = ggml_add(ctx0, inp, model.conv2d_3_b); + inp = ggml_gelu_erf(ctx0, inp); + + // inp is now [time, frames, channels] + cb(inp, "after_conv_blocks", -1); + + inp = ggml_permute(ctx0, inp, 2, 1, 0, 3); // [channels, frames, time] + inp = ggml_cont(ctx0, inp); + inp = ggml_reshape_2d(ctx0, inp, inp->ne[0] * inp->ne[1], inp->ne[2]); // [channels * time, frames] + + // project to n_embd + inp = ggml_mul_mat(ctx0, model.conv_out_w, inp); + if (model.conv_out_b) { + inp = ggml_add(ctx0, inp, model.conv_out_b); + } + cb(inp, "after_conv_out", -1); + } + + auto n_pos = inp->ne[1]; + + ggml_tensor * pos_embd_selected = ggml_view_2d( + ctx0, model.position_embeddings, + model.position_embeddings->ne[0], n_pos, + model.position_embeddings->nb[1], 0 + ); + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + pos_embd_selected, + nullptr); + + cb(cur, "after_transformer", -1); + + // projector + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + + cb(cur, "projected", -1); + + // pad deepstack if needed + // TODO: do NOT hard code 3 here + cur = ggml_pad(ctx0, cur, cur->ne[0] * 3, 0, 0, 0); + + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/whisper-enc.cpp b/tools/mtmd/models/whisper-enc.cpp index a541857797..2f2b127755 100644 --- a/tools/mtmd/models/whisper-enc.cpp +++ b/tools/mtmd/models/whisper-enc.cpp @@ -19,18 +19,9 @@ ggml_cgraph * clip_graph_whisper_enc::build() { cur = ggml_add(ctx0, cur, model.conv1d_2_b); cur = ggml_gelu_erf(ctx0, cur); - // transpose inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); cb(inp, "after_conv1d", -1); - - if (model.conv_out_w) { - inp = ggml_mul_mat(ctx0, model.conv_out_w, inp); - if (model.conv_out_b) { - inp = ggml_add(ctx0, inp, model.conv_out_b); - } - cb(inp, "after_conv_out", -1); - } } // sanity check (only check one layer, but it should be the same for all) @@ -86,15 +77,6 @@ ggml_cgraph * clip_graph_whisper_enc::build() { cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); cur = ggml_add(ctx0, cur, model.mm_fc_b); - } else if (proj_type == PROJECTOR_TYPE_QWEN3A) { - // projector - cur = build_ffn(cur, - model.mm_1_w, model.mm_1_b, - nullptr, nullptr, - model.mm_2_w, model.mm_2_b, - FFN_GELU_ERF, - -1); - } else if (proj_type == PROJECTOR_TYPE_VOXTRAL) { // projector cur = build_ffn(cur, diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b0b5ab42ab..c18a25b7d3 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -326,6 +326,7 @@ struct mtmd_context { // set preprocessor switch (proj) { case 
PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN3A: case PROJECTOR_TYPE_QWEN25O: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: @@ -344,7 +345,7 @@ struct mtmd_context { audio_preproc->initialize(); // set special tokens - if (proj == PROJECTOR_TYPE_QWEN2A) { + if (proj == PROJECTOR_TYPE_QWEN2A || proj == PROJECTOR_TYPE_QWEN3A) { // <|audio_bos|> ... (embeddings) ... <|audio_eos|> aud_beg = "<|audio_bos|>"; aud_end = "<|audio_eos|>"; From e0adcf723227f4f22d2701312cbc6e6827a51965 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Apr 2026 00:22:17 +0200 Subject: [PATCH 3/6] vision ok --- convert_hf_to_gguf.py | 21 +++++++++++++-------- tools/mtmd/clip.cpp | 9 +++++++-- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d559d70631..0ca9f97151 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4936,7 +4936,7 @@ class Qwen3VLVisionModel(MmprojModel): return if name.startswith("visual."): - return [(self.map_tensor_name(name), data_torch)] + yield (self.map_tensor_name(name), data_torch) return [] # skip other tensors @@ -4975,7 +4975,6 @@ class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel): # transform conv2d bias [n_embd] --> [1, 1, n_embd] data_torch = data_torch.unsqueeze(-1).unsqueeze(-1) yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid) - return [] @ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration") @@ -5028,14 +5027,20 @@ class Qwen3VLTextModel(Qwen3Model): class Qwen3VLMoeTextModel(Qwen3MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3VLMOE + def set_vocab(self): + super().set_vocab() + # correct BOS/EOS tokens + with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + added_tokens = tokenizer_config.get("added_tokens_decoder", {}) + for token_id, data in added_tokens.items(): + if data.get("content") == "<|im_end|>": + self.gguf_writer.add_bos_token_id(int(token_id)) + self.gguf_writer.add_eos_token_id(int(token_id)) + def set_gguf_parameters(self): super().set_gguf_parameters() - if "thinker_config" in self.hparams: - vision_config = self.hparams["thinker_config"].get("vision_config", {}) - else: - vision_config = self.hparams.get("vision_config", {}) - deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", [])) - self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num) + self.gguf_writer.add_num_deepstack_layers(0) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Skip vision tensors - they go in the mmproj file diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3178ea6601..12a298e9a8 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2669,7 +2669,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_QWEN3A: { - return 375; // TODO: calculate this + // 3x stride-2 conv2d: each step is floor((n-1)/2)+1 + int n = img->nx; + n = (n - 1) / 2 + 1; + n = (n - 1) / 2 + 1; + n = (n - 1) / 2 + 1; + n_patches = n; } break; case PROJECTOR_TYPE_GLMA: { @@ -3256,7 +3261,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; case PROJECTOR_TYPE_QWEN3A: - return ctx->model.mm_2_w->ne[1] * 4; // 4 for deepstack, TODO: do NOT hardcode + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_GLMA: case 
PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: From 792edb7d7a07e5a55c2e651a84f69fe8b46db402 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Apr 2026 00:36:47 +0200 Subject: [PATCH 4/6] no more deepstack for audio --- tools/mtmd/models/qwen3a.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tools/mtmd/models/qwen3a.cpp b/tools/mtmd/models/qwen3a.cpp index 2680073290..1384e5155e 100644 --- a/tools/mtmd/models/qwen3a.cpp +++ b/tools/mtmd/models/qwen3a.cpp @@ -18,12 +18,15 @@ ggml_cgraph * clip_graph_qwen3a::build() { inp = ggml_add(ctx0, inp, model.conv2d_3_b); inp = ggml_gelu_erf(ctx0, inp); - // inp is now [time, frames, channels] + // inp [n_pos, n_mels/8, channels, 1] (W, H, C, N) cb(inp, "after_conv_blocks", -1); - inp = ggml_permute(ctx0, inp, 2, 1, 0, 3); // [channels, frames, time] - inp = ggml_cont(ctx0, inp); - inp = ggml_reshape_2d(ctx0, inp, inp->ne[0] * inp->ne[1], inp->ne[2]); // [channels * time, frames] + const int64_t n_pos_after_conv = inp->ne[0]; + const int64_t n_mel_after_conv = inp->ne[1]; // 128/8 = 16 + + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 3, 1)); + inp = ggml_reshape_2d(ctx0, inp, n_pos_after_conv, n_mel_after_conv * inp->ne[3]); // [n_pos, 7680] + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [7680, n_pos] // project to n_embd inp = ggml_mul_mat(ctx0, model.conv_out_w, inp); @@ -59,10 +62,6 @@ ggml_cgraph * clip_graph_qwen3a::build() { cb(cur, "projected", -1); - // pad deepstack if needed - // TODO: do NOT hard code 3 here - cur = ggml_pad(ctx0, cur, cur->ne[0] * 3, 0, 0, 0); - ggml_build_forward_expand(gf, cur); return gf; From 172865e93c2eddab0809771cd606c0c8897b13a1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Apr 2026 00:55:44 +0200 Subject: [PATCH 5/6] convert ASR model ok --- convert_hf_to_gguf.py | 79 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 6 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0ca9f97151..d1279668d4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4825,7 +4825,10 @@ class RND1Model(Qwen2MoeModel): class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - assert self.hparams_vision is not None + if self.hparams_vision is None: + logger.info("No vision config found, skipping vision tensor processing") + return + # Compute image_size if not present if "image_size" not in self.hparams_vision: # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings @@ -4946,18 +4949,29 @@ class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel): has_vision_encoder = True def get_vision_config(self) -> dict[str, Any] | None: - return self.global_config["thinker_config"].get("vision_config") + if self.has_vision_encoder: + return self.global_config["thinker_config"].get("vision_config") + else: + return None def get_audio_config(self) -> dict[str, Any] | None: - return self.global_config["thinker_config"].get("audio_config") + if self.has_audio_encoder: + return self.global_config["thinker_config"].get("audio_config") + else: + return None def set_gguf_parameters(self): - super().set_gguf_parameters() - self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL) - self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A) + if self.has_vision_encoder: + Qwen3VLVisionModel.set_gguf_parameters(self) + self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL) + if 
self.has_audio_encoder: + Qwen25AudioModel.set_gguf_parameters(self) + self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if "visual." in name: + if not self.has_vision_encoder: + raise ValueError(f"Model does not have vision encoder, but found tensor {name}") # need to transform vision tensor naming, so that modify_tensors() logic can be used correctly name = name.replace("thinker.visual.", "model.visual.") if ".merger_list." in name: @@ -4971,12 +4985,20 @@ class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel): name = name.replace(".mlp.2", ".linear_fc2") yield from Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid) elif "audio_tower." in name: + if not self.has_audio_encoder: + raise ValueError(f"Model does not have audio encoder, but found tensor {name}") if "conv2d" in name and name.endswith(".bias"): # transform conv2d bias [n_embd] --> [1, 1, n_embd] data_torch = data_torch.unsqueeze(-1).unsqueeze(-1) yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid) +@ModelBase.register("Qwen3ASRForConditionalGeneration") +class Qwen3ASRMmprojModel(Qwen3OmniMmprojModel): + has_audio_encoder = True + has_vision_encoder = False + + @ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration") class Glm4VVisionModel(Qwen3VLVisionModel): def set_gguf_parameters(self): @@ -5023,6 +5045,31 @@ class Qwen3VLTextModel(Qwen3Model): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Qwen3ASRForConditionalGeneration") +class Qwen3ASRTextModel(Qwen3VLTextModel): + model_arch = gguf.MODEL_ARCH.QWEN3VL + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_num_deepstack_layers(0) + + def set_vocab(self): + super().set_vocab() + # fix chat template, use correct chatml format + self.gguf_writer.add_chat_template("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}") + + def modify_tensors(self, data_torch, name, bid): + # qwen3-omni + name = name.replace("thinker.", "") + + # Skip vision and audio tensors - they go in the mmproj file + if "visual." in name or "audio_tower." in name \ + or "talker." in name or "code2wav." 
in name: + return + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Qwen3VLMoeForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration") class Qwen3VLMoeTextModel(Qwen3MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3VLMOE @@ -5083,6 +5130,26 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Qwen3OmniMoeForConditionalGeneration") +class Qwen3OmniMoeTextModel(Qwen3VLMoeTextModel): + model_arch = gguf.MODEL_ARCH.QWEN3VLMOE + + def set_vocab(self): + super().set_vocab() + # correct BOS/EOS tokens + with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + added_tokens = tokenizer_config.get("added_tokens_decoder", {}) + for token_id, data in added_tokens.items(): + if data.get("content") == "<|im_end|>": + self.gguf_writer.add_bos_token_id(int(token_id)) + self.gguf_writer.add_eos_token_id(int(token_id)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_num_deepstack_layers(0) + + class _LinearAttentionVReorderBase(Qwen3NextModel): model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses """reorders V heads from grouped to tiled order for ggml broadcast From eefcfeed7abcafc74901c46721e6c3324b85c144 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Apr 2026 01:01:30 +0200 Subject: [PATCH 6/6] qwen3 asr working --- convert_hf_to_gguf.py | 8 ++++++++ tools/mtmd/mtmd.cpp | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d1279668d4..05d8e99284 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5057,6 +5057,14 @@ class Qwen3ASRTextModel(Qwen3VLTextModel): super().set_vocab() # fix chat template, use correct chatml format self.gguf_writer.add_chat_template("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}") + # correct BOS/EOS tokens + with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + added_tokens = tokenizer_config.get("added_tokens_decoder", {}) + for token_id, data in added_tokens.items(): + if data.get("content") == "<|im_end|>": + self.gguf_writer.add_bos_token_id(int(token_id)) + self.gguf_writer.add_eos_token_id(int(token_id)) def modify_tensors(self, data_torch, name, bid): # qwen3-omni diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 2a8c4e5aad..8219591ed2 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -982,6 +982,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) { } bool mtmd_decode_use_mrope(mtmd_context * ctx) { + if (ctx->ctx_v == nullptr && ctx->proj_type_a() == PROJECTOR_TYPE_QWEN3A) { + // qwen3-asr + return true; + } switch (ctx->proj_type_v()) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL:
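
Note on the audio token count: the clip_n_output_tokens() change in PATCH 3/6 replaces the hardcoded 375 with the effect of three stride-2 conv2d layers along the time axis. A minimal Python sketch of that arithmetic follows; the 3000-frame example assumes the usual Whisper-style 30 s window with a 10 ms mel hop, which is not stated in the patches themselves:

    def qwen3a_n_output_tokens(n_frames: int) -> int:
        # three stride-2 conv2d layers: each maps n -> floor((n - 1) / 2) + 1
        n = n_frames
        for _ in range(3):
            n = (n - 1) // 2 + 1
        return n

    # 3000 mel frames -> 375 output tokens, matching the value
    # that PATCH 2/6 temporarily hardcoded.
    assert qwen3a_n_output_tokens(3000) == 375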