model : add Qwen3-Omni multimodal architecture support

Adds support for Qwen3-Omni, Alibaba's multimodal LLM. This commit
covers the text (main LLM) architecture and the vision encoder.

Main LLM changes:
- Add LLM_ARCH_QWEN3OMNI enum and architecture registration
- Add hparams loading for MoE-based architecture (48 layers, 128 experts)
- Reuse llm_build_qwen3moe graph builder
- Map the architecture to the IMROPE rope type for multimodal position encoding (see the sketch below)
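
For reference, a minimal Python sketch of the mrope_section padding done in
Qwen3OmniMoeTextModel.set_gguf_parameters() further down; the [24, 20, 20]
values are illustrative only, not taken from the actual Qwen3-Omni config:

    # Sketch of the padding performed before writing rope sections to GGUF.
    # Example section values are illustrative, not the real model config.
    def pad_mrope_section(mrope_section: list[int]) -> list[int]:
        # mrope_section holds the rope dims for [time, height, width];
        # GGUF stores four entries, so zero-pad and truncate to 4
        section = list(mrope_section)
        while len(section) < 4:
            section.append(0)
        return section[:4]

    print(pad_mrope_section([24, 20, 20]))  # -> [24, 20, 20, 0]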

Vision encoder changes (via mtmd):
- Add PROJECTOR_TYPE_QWEN3O with auto-conversion to QWEN3VL for vision
- Support different embedding dimensions (vision=8192, audio=2048); see the worked example after this list
- Add separate Q/K/V tensor support in qwen3vl graph builder
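
The vision/audio mismatch noted above is expected rather than a bug: the vision
embedding concatenates deepstack features on top of the base text embedding,
while the audio projector outputs the text embedding size directly. A worked
example, assuming the 30B variant's text n_embd of 2048 and three deepstack
layers (assumed values, inferred from the mtmd note in the last diff below
rather than read from the config):

    # Worked example of the embedding-dimension note in the mtmd change.
    # Assumed values: text n_embd = 2048, three deepstack vision layers.
    n_embd_text        = 2048
    n_deepstack_layers = 3

    n_embd_vision = n_embd_text * (1 + n_deepstack_layers)  # deepstack concatenation
    n_embd_audio  = 2048                                     # audio projection_dim, no deepstack

    assert n_embd_vision == 8192
    assert n_embd_audio  == 2048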

Tested with Qwen3-Omni-30B-Q8_0.gguf on a distributed 5-GPU setup:
- 41-44 tokens/sec inference speed
- Text and vision inference working

Note: Audio encoder support is WIP and will follow in a separate PR.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: EliteGPT AI
Date:   2025-12-31 12:57:45 +10:00
Parent: 9a6369bb60
Commit: 7bab4a3065
9 changed files with 409 additions and 19 deletions


@@ -4537,6 +4537,315 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
class Qwen3OmniModel(MmprojModel):
"""Qwen3-Omni multimodal model converter for audio + vision encoders.
Key differences from Qwen2.5-Omni:
- Audio uses conv2d1/conv2d2/conv2d3 (not conv1/conv2)
- Audio has conv_out, ln_post, proj1, proj2
- Vision has merger_list (deepstack) like Qwen3-VL
- Vision patch_embed is Conv3D (needs 5D->4D tensor splitting)
- Vision has explicit pos_embed.weight
"""
has_vision_encoder = True
has_audio_encoder = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Setup audio config
assert self.hparams_audio is not None
self.hparams_audio["hidden_size"] = self.hparams_audio.get("d_model")
self.hparams_audio["intermediate_size"] = self.hparams_audio.get("encoder_ffn_dim")
self.hparams_audio["num_attention_heads"] = self.hparams_audio.get("encoder_attention_heads")
# Setup vision config
assert self.hparams_vision is not None
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
# Handle image_size - may need to compute from other params
if "image_size" not in self.hparams_vision or self.hparams_vision["image_size"] is None:
self.hparams_vision["image_size"] = 768 # Default for Qwen3-Omni
# Track deepstack layers
self.is_deepstack_layers = [False] * int(self.hparams_vision.get("num_hidden_layers", 27) or 27)
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
self.is_deepstack_layers[idx] = True
def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config.get("thinker_config", {}).get("vision_config")
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("thinker_config", {}).get("audio_config")
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Set projector type for Qwen3-Omni
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3O)
# Audio parameters
assert self.hparams_audio is not None
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio.get("num_mel_bins", 128))
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
# Vision parameters
self.gguf_writer.add_vision_use_gelu(True) # Qwen3-Omni uses GELU
# Vision attention layernorm eps from text config
text_config = self.global_config.get("thinker_config", {}).get("text_config", {})
rms_norm_eps = text_config.get("rms_norm_eps", 1e-6)
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
# Deepstack layers for vision
if any(self.is_deepstack_layers):
self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
"""Generate sinusoidal position embeddings for audio encoder."""
assert self.hparams_audio is not None
max_timescale = 10000
length = 1500 # Max audio sequence length
channels = self.hparams_audio.get("hidden_size", 1280)
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32)
yield ("audio_tower.embed_positions.weight", pos_embd)
def tensor_force_quant(self, name, new_name, bid, n_dims):
# Keep conv layers in higher precision
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F16
return super().tensor_force_quant(name, new_name, bid, n_dims)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Strip thinker prefix
if name.startswith("thinker."):
name = name.replace("thinker.", "")
# Skip text model tensors (handled by text model converter)
if name.startswith("model.") or name.startswith("lm_head") or name.startswith("embed_tokens"):
return []
# Skip talker and code2wav (not needed for inference)
if name.startswith("talker.") or name.startswith("code2wav."):
return []
# Handle audio tensors
if name.startswith("audio_tower"):
# Strip audio_tower. prefix for processing
audio_name = name.replace("audio_tower.", "")
# Skip embed_positions - we generate sinusoidal positions in generate_extra_tensors
if audio_name.startswith("embed_positions"):
return []
# Handle conv2d1/2/3 - map to a.enc_conv1d.{bid}
for i in [1, 2, 3]:
if audio_name.startswith(f"conv2d{i}."):
suffix = audio_name.split(".", 1)[1] # weight or bias
if suffix == "bias":
data_torch = data_torch.unsqueeze(-1)
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.A_ENC_CONV1D, i - 1, suffix=f".{suffix}")
return [(new_name, data_torch)]
# Handle conv_out - use a separate conv layer index
if audio_name.startswith("conv_out."):
suffix = audio_name.split(".", 1)[1]
if suffix == "bias":
data_torch = data_torch.unsqueeze(-1)
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.A_ENC_CONV1D, 3, suffix=f".{suffix}")
return [(new_name, data_torch)]
# Handle ln_post - post normalization
if audio_name.startswith("ln_post."):
suffix = audio_name.split(".", 1)[1]
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.A_POST_NORM, suffix=f".{suffix}")
return [(new_name, data_torch)]
# Handle proj1/proj2 - audio multimodal projector (use A_MMPROJ which supports bid)
if audio_name.startswith("proj1."):
suffix = audio_name.split(".", 1)[1]
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 0, suffix=f".{suffix}")
return [(new_name, data_torch)]
if audio_name.startswith("proj2."):
suffix = audio_name.split(".", 1)[1]
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 1, suffix=f".{suffix}")
return [(new_name, data_torch)]
# Handle encoder layers - transform to Whisper-compatible names and use map_tensor_name
if audio_name.startswith("layers."):
# Qwen3-Omni uses same layer naming as Whisper/Ultravox
# audio_tower.layers.{bid}.self_attn.q_proj -> audio_tower.layers.{bid}.self_attn.q_proj
# Just add back the audio_tower prefix and use map_tensor_name
return [(self.map_tensor_name(name), data_torch)]
# Fallback for any other audio tensors
logger.warning(f"Unknown audio tensor: {name}")
return [(self.map_tensor_name(name), data_torch)]
# Handle visual tensors
if name.startswith("visual."):
# Handle merger_list (deepstack)
if name.startswith("visual.merger_list."):
# Format: visual.merger_list.{idx}.{ln_q|mlp}.{layer}.{weight|bias}
parts = name.split(".")
idx = int(parts[2]) # merger_list index
# Get actual layer index from deepstack_visual_indexes
deepstack_indexes = self.hparams_vision.get("deepstack_visual_indexes", [])
if idx < len(deepstack_indexes):
layer_idx = deepstack_indexes[idx]
else:
layer_idx = idx # Fallback
suffix_parts = parts[3:] # Everything after the index
suffix = ".".join(suffix_parts)
if suffix.startswith("ln_q"):
tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
tail = suffix.split(".", 1)[1] if "." in suffix else "weight"
elif suffix.startswith("mlp.0"):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
tail = suffix.split(".", 2)[2] if suffix.count(".") >= 2 else "weight"
elif suffix.startswith("mlp.2"):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
tail = suffix.split(".", 2)[2] if suffix.count(".") >= 2 else "weight"
else:
raise ValueError(f"Unexpected deepstack tensor: {name}")
new_name = self.format_tensor_name(tensor_type, layer_idx, suffix=f".{tail}")
return [(new_name, data_torch)]
# Handle main merger
if name.startswith("visual.merger."):
suffix = name.split(".", 2)[2]
if suffix.startswith("mlp.0"):
# First FC layer
tail = suffix.split(".", 2)[2] if suffix.count(".") >= 2 else "weight"
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, 0, suffix=f".{tail}")
elif suffix.startswith("mlp.2"):
# Second FC layer
tail = suffix.split(".", 2)[2] if suffix.count(".") >= 2 else "weight"
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, 2, suffix=f".{tail}")
elif suffix.startswith("ln_q"):
tail = suffix.split(".", 1)[1] if "." in suffix else "weight"
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{tail}")
else:
raise ValueError(f"Unexpected merger tensor: {name}")
return [(new_name, data_torch)]
# Handle QKV split for attention
if ".qkv." in name:
if data_torch.ndim == 2: # weight
c3, _ = data_torch.shape
else: # bias
c3 = data_torch.shape[0]
assert c3 % 3 == 0
c = c3 // 3
wq = data_torch[:c]
wk = data_torch[c:c * 2]
wv = data_torch[c * 2:]
return [
(self.map_tensor_name(name.replace("qkv", "q")), wq),
(self.map_tensor_name(name.replace("qkv", "k")), wk),
(self.map_tensor_name(name.replace("qkv", "v")), wv),
]
# Handle patch_embed - Conv3D needs splitting to 4D tensors (GGUF max is 4D)
if name == "visual.patch_embed.proj.weight":
# Split Conv3D into Conv2Ds along temporal dimension
if data_torch.ndim == 5:
c1, c2, kt, kh, kw = data_torch.shape
del c1, c2, kh, kw
if kt != 2:
raise ValueError("Current implementation only supports temporal_patch_size of 2")
return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
]
return [(self.map_tensor_name(name), data_torch)]
if name == "visual.patch_embed.proj.bias":
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
# Default handling for other visual tensors
return [(self.map_tensor_name(name), data_torch)]
# Fall back to parent for any other tensors
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
class Qwen3OmniMoeTextModel(Qwen3MoeModel):
"""Qwen3-Omni MoE text model converter.
Converts the text model (thinker.model.*) from Qwen3-Omni to GGUF format.
The audio and vision encoders are handled by Qwen3OmniModel (mmproj converter).
Key differences from Qwen3VLMoeTextModel:
- Tensor prefix is thinker.model.* (not model.*)
- Must skip: thinker.audio_tower, thinker.visual, talker, code2wav
- Config structure: thinker_config.text_config (handled by load_hparams)
"""
model_arch = gguf.MODEL_ARCH.QWEN3OMNI
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-Omni
# The text_config is already merged into hparams by load_hparams
rope_scaling = self.hparams.get("rope_scaling") or self.hparams.get("rope_parameters") or {}
if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
logger.info(f"MRoPE sections: {mrope_section[:4]}")
# Get vision config for deepstack layers (from thinker_config in hparams)
thinker_config = self.hparams.get("thinker_config", {})
vision_config = thinker_config.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
if deepstack_layer_num > 0:
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip multimodal encoders - they go in the mmproj file
if name.startswith("thinker.audio_tower."):
return []
if name.startswith("thinker.visual."):
return []
# Skip talker (speech synthesis) and code2wav (audio generation) - not needed for text inference
if name.startswith("talker."):
return []
if name.startswith("code2wav."):
return []
# Strip thinker prefix to get standard tensor names
# Original names:
# thinker.model.layers.* -> model.layers.*
# thinker.model.embed_tokens.* -> model.embed_tokens.*
# thinker.model.norm.* -> model.norm.*
# thinker.lm_head.* -> lm_head.* (NOT model.lm_head!)
if name.startswith("thinker.model."):
name = name.replace("thinker.model.", "model.", 1)
elif name.startswith("thinker."):
# Handle other thinker tensors (lm_head, etc.) - just strip thinker.
name = name.replace("thinker.", "", 1)
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
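
As a standalone illustration of the visual.patch_embed.proj.weight handling
above: a Conv3D weight with temporal_patch_size == 2 is split into two 4D
Conv2D weights, since GGUF tensors are limited to four dimensions. A minimal
PyTorch sketch with illustrative (not actual Qwen3-Omni) shapes:

    # Sketch of the Conv3D -> 2x Conv2D split done in Qwen3OmniModel.modify_tensors().
    # The shape below is illustrative, not the real patch_embed shape.
    import torch

    w = torch.randn(1152, 3, 2, 16, 16)   # (out_ch, in_ch, kt, kh, kw), kt == temporal_patch_size
    w_t0 = w[:, :, 0, ...]                # first temporal slice  -> "...patch_embd.weight"
    w_t1 = w[:, :, 1, ...]                # second temporal slice -> "...patch_embd.weight.1"

    assert w_t0.shape == w_t1.shape == (1152, 3, 16, 16)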


@@ -372,6 +372,7 @@ class MODEL_ARCH(IntEnum):
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
QWEN3OMNI = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
@@ -769,6 +770,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.QWEN3OMNI: "qwen3omni",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
@@ -1720,6 +1722,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN3OMNI: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -3485,6 +3504,7 @@ class VisionProjectorType:
QWEN2A = "qwen2a" # audio
GLMA = "glma" # audio
QWEN25O = "qwen2.5o" # omni
QWEN3O = "qwen3o" # qwen3-omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"


@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_QWEN3OMNI, "qwen3omni" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
@@ -915,6 +916,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
};
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_QWEN3OMNI:
case LLM_ARCH_OLMOE:
case LLM_ARCH_LLADA_MOE:
case LLM_ARCH_RND1:


@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_QWEN3OMNI,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,


@@ -1144,6 +1144,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3OMNI:
{
ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 48: type = LLM_TYPE_30B_A3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_PHI2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3598,6 +3608,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_QWEN3OMNI:
case LLM_ARCH_RND1:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7115,7 +7126,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_QWEN3OMNI || arch == LLM_ARCH_RND1) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
}
@@ -7510,6 +7521,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_qwen3vlmoe>(*this, params);
} break;
case LLM_ARCH_QWEN3OMNI:
{
llm = std::make_unique<llm_build_qwen3moe>(*this, params);
} break;
case LLM_ARCH_PHI2:
{
llm = std::make_unique<llm_build_phi2>(*this, params);
@@ -8081,6 +8096,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
return LLAMA_ROPE_TYPE_MROPE;
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_QWEN3OMNI:
return LLAMA_ROPE_TYPE_IMROPE;
case LLM_ARCH_GLM4:


@@ -169,6 +169,7 @@ enum projector_type {
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_QWEN2VL,
PROJECTOR_TYPE_QWEN3VL,
PROJECTOR_TYPE_QWEN3O, // qwen3-omni: converts to QWEN3VL for vision, uses custom encoder for audio
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
@@ -199,6 +200,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
{ PROJECTOR_TYPE_QWEN3O, "qwen3o"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},


@@ -970,6 +970,18 @@ struct clip_model_loader {
? PROJECTOR_TYPE_QWEN25VL
: PROJECTOR_TYPE_QWEN2A;
}
// Qwen3-Omni: vision uses qwen3vl pipeline, audio stays qwen3o
if (model.proj_type == PROJECTOR_TYPE_QWEN3O) {
projector_type new_type = modality == CLIP_MODALITY_VISION
? PROJECTOR_TYPE_QWEN3VL
: PROJECTOR_TYPE_QWEN3O;
LOG_INF("%s: QWEN3O auto-conversion: %s -> %s (modality=%s)\n", __func__,
PROJECTOR_TYPE_NAMES[model.proj_type].c_str(),
PROJECTOR_TYPE_NAMES[new_type].c_str(),
modality == CLIP_MODALITY_VISION ? "vision" : "audio");
model.proj_type = new_type;
}
}
const bool is_vision = model.modality == CLIP_MODALITY_VISION;


@@ -85,23 +85,43 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
// self-attention
{
    // Support both separate Q/K/V (Qwen3-Omni) and combined QKV (Qwen3-VL)
    ggml_tensor * Qcur;
    ggml_tensor * Kcur;
    ggml_tensor * Vcur;
    if (layer.qkv_w) {
        // Combined QKV format
        cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
        cur = ggml_add(ctx0, cur, layer.qkv_b);
        Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
            /* nb1 */ ggml_row_size(cur->type, d_head),
            /* nb2 */ cur->nb[1],
            /* offset */ 0);
        Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
            /* nb1 */ ggml_row_size(cur->type, d_head),
            /* nb2 */ cur->nb[1],
            /* offset */ ggml_row_size(cur->type, n_embd));
        Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
            /* nb1 */ ggml_row_size(cur->type, d_head),
            /* nb2 */ cur->nb[1],
            /* offset */ ggml_row_size(cur->type, 2 * n_embd));
    } else {
        // Separate Q/K/V format (like Qwen3-Omni)
        Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
        Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
        Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
        Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
        Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
        Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
    }
    cb(Qcur, "Qcur", il);
    cb(Kcur, "Kcur", il);
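
For context on the branch above: Qwen3-VL ships a fused QKV weight while
Qwen3-Omni ships separate q/k/v tensors, and the two layouts are
interchangeable because slicing the fused projection recovers the separate
projections. A quick PyTorch check (outside ggml, illustrative sizes only,
assuming the fused weight is the row-wise concatenation of q/k/v):

    # Property the combined-vs-separate QKV handling relies on.
    import torch

    n_embd, n_tokens = 64, 5
    q_w, k_w, v_w = (torch.randn(n_embd, n_embd) for _ in range(3))
    x = torch.randn(n_tokens, n_embd)

    qkv_w = torch.cat([q_w, k_w, v_w], dim=0)   # fused weight, shape (3*n_embd, n_embd)
    fused = x @ qkv_w.T                         # (n_tokens, 3*n_embd)
    q_f, k_f, v_f = fused.split(n_embd, dim=-1)

    assert torch.allclose(q_f, x @ q_w.T, atol=1e-5)
    assert torch.allclose(k_f, x @ k_w.T, atol=1e-5)
    assert torch.allclose(v_f, x @ v_w.T, atol=1e-5)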


@@ -188,7 +188,15 @@ struct mtmd_context {
}
// if both vision and audio mmproj are present, we need to validate their n_embd
// Note: QWEN3O has different embedding dimensions for vision and audio, which is valid
// - Vision uses deepstack, so n_embd_v = n_embd * (1 + n_deepstack_layers) = 8192
// - Audio doesn't use deepstack, so n_embd_a = projection_dim = 2048
projector_type proj_v = ctx_v ? clip_get_projector_type(ctx_v) : PROJECTOR_TYPE_UNKNOWN;
projector_type proj_a = ctx_a ? clip_get_projector_type(ctx_a) : PROJECTOR_TYPE_UNKNOWN;
bool is_qwen3o = (proj_v == PROJECTOR_TYPE_QWEN3VL || proj_v == PROJECTOR_TYPE_QWEN3O) &&
    (proj_a == PROJECTOR_TYPE_QWEN3O);
if (ctx_v && ctx_a && !is_qwen3o) {
int n_embd_v = clip_n_mmproj_embd(ctx_v);
int n_embd_a = clip_n_mmproj_embd(ctx_a);
if (n_embd_v != n_embd_a) {
@@ -198,9 +206,9 @@ }
}
}
// For QWEN3O, use vision embedding dimension (includes deepstack) for validation
// For other models, vision and audio should have same embedding dimension
int n_embd_clip = is_qwen3o ? clip_n_mmproj_embd(ctx_v) : clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a);
if (n_embd_text != n_embd_clip) {
throw std::runtime_error(string_format(
"mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"