model: add support for qwen3vl series (#16780)
* support qwen3vl series.
Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>
* bugfix: fix the arch check for qwen3vl-moe.
* use build_ffn
* optimize deepstack structure
* optimize deepstack feature saving
* Revert "optimize deepstack feature saving" as a temporary fix
This reverts commit f321b9fdf1.
* code clean
* use fused qkv in clip
* clean up / rm is_deepstack_layers for simplification
* add test model
* move test model to "big" section
* fix imrope check
* remove trailing whitespace
* fix rope fail
* metal : add imrope support
* add imrope support for sycl
* vulkan: add imrope w/o check
* fix vulkan
* webgpu: add imrope w/o check
* Update gguf-py/gguf/tensor_mapping.py
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* fix tensor mapping
---------
Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent dcca0d3ab8
commit d261223d24
@@ -3852,7 +3852,43 @@ class Qwen2MoeModel(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "") # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
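The dimension bookkeeping in the comments above is easy to check numerically. A minimal sketch, assuming PyTorch is available; the 128/768/2048 shapes are the example values quoted in the comments, not anything the converter fixes:

```python
# Sketch: verify that permute(0, 2, 1) yields the PyTorch shape whose
# reversal is the GGML ne layout expected for the expert tensors.
import torch

n_expert, n_ff_exp, n_embd = 128, 768, 2048

# down_proj arrives as (n_expert, n_ff_exp, n_embd)
down = torch.zeros(n_expert, n_ff_exp, n_embd)
down_perm = down.permute(0, 2, 1).contiguous()  # -> (128, 2048, 768)
# GGUF writes dimensions reversed, so GGML sees ne = {768, 2048, 128}
assert tuple(reversed(down_perm.shape)) == (n_ff_exp, n_embd, n_expert)

# each gate/up half arrives as (n_expert, n_embd, n_ff_exp) after the split
gate = torch.zeros(n_expert, n_embd, n_ff_exp)
gate_perm = gate.permute(0, 2, 1).contiguous()  # -> (128, 768, 2048)
assert tuple(reversed(gate_perm.shape)) == (n_embd, n_ff_exp, n_expert)
```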
@@ -4004,6 +4040,187 @@ class Qwen3MoeModel(Qwen2MoeModel):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
+            image_size = int(num_pos**0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
+        for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
+            self.is_deepstack_layers[idx] = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps
+        rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.is_deepstack_layers:
+            self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        assert self.hparams_vision is not None
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            # prefix is the layer index, convert to absolute clip layer index!
+            idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
+            target = rest
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2
+                # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2
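Two small pieces of arithmetic from the classes above, spelled out: the image_size recovery inverts num_position_embeddings = (image_size / patch_size)², and the MRoPE section list is padded to the four [time, height, width, extra] slots the runtime expects. A sketch with illustrative values (2304 and 16 are the defaults used above; the [24, 20, 20] split is hypothetical, not taken from a real config):

```python
# image_size from num_position_embeddings:
#   num_pos = (image_size / patch_size) ** 2  =>  image_size = sqrt(num_pos) * patch_size
num_pos, patch_size = 2304, 16
image_size = int(num_pos**0.5 * patch_size)  # sqrt(2304) = 48 -> 48 * 16 = 768
assert image_size == 768

# mrope_section padding: [t, h, w] -> [t, h, w, extra]
mrope_section = [24, 20, 20]  # hypothetical section split
while len(mrope_section) < 4:
    mrope_section.append(0)
assert mrope_section == [24, 20, 20, 0]
```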
@@ -242,6 +242,7 @@
 #define GGML_ROPE_TYPE_NEOX   2
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
+#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
 
 #define GGML_MROPE_SECTIONS 4
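The new constant keeps the MROPE bit set, which is why the backends below can test `mode & GGML_ROPE_TYPE_MROPE` to route imrope through the existing multi-rope kernels, and then `mode == GGML_ROPE_TYPE_IMROPE` to select the interleaved layout. A sketch of the bit arithmetic:

```python
# Rope mode flags as defined in the header above.
NEOX, MROPE, VISION, IMROPE = 2, 8, 24, 40

assert IMROPE == 0b101000               # the "binary: 101000" comment
assert IMROPE & MROPE                   # imrope is dispatched like mrope...
assert IMROPE not in (MROPE, VISION)    # ...but its exact value is distinct
assert VISION == (16 | MROPE)           # vision likewise builds on the mrope bit
```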
@@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
 }
 
 static void ggml_mrope_cache_init(
-     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
      float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
      float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,6 +5509,17 @@ static void ggml_mrope_cache_init(
         }
 
         float theta = theta_t;
+        if (is_imrope) { // qwen3vl apply interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
+            } else {
+                theta = theta_e;
+            }
+        } else {
         if (sector >= sections[0] && sector < sec_w) {
             theta = theta_h;
         }
@@ -5518,6 +5529,7 @@ static void ggml_mrope_cache_init(
         else if (sector >= sec_w + sections[2]) {
             theta = theta_e;
         }
+        }
 
         rope_yarn(
             theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
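To see what the interleaving changes, here is a sketch that prints which positional axis (t/h/w/e) each sector of the rotary dimensions uses, for contiguous mrope versus interleaved imrope, mirroring the branch above (the sections values are illustrative, not a real model's split):

```python
def axis_for_sector(sector: int, sections: list[int], interleaved: bool) -> str:
    t, h, w, _e = sections
    if interleaved:  # qwen3vl imrope: t/h/w cycle every 3 sectors
        if sector % 3 == 1 and sector < 3 * h:
            return "h"
        if sector % 3 == 2 and sector < 3 * w:
            return "w"
        if sector % 3 == 0 and sector < 3 * t:
            return "t"
        return "e"
    # contiguous mrope: [t... h... w... e...]
    if sector < t:
        return "t"
    if sector < t + h:
        return "h"
    if sector < t + h + w:
        return "w"
    return "e"

sections = [4, 3, 3, 2]  # illustrative split; sectors total 12
print("".join(axis_for_sector(s, sections, False) for s in range(12)))  # tttthhhwwwee
print("".join(axis_for_sector(s, sections, True) for s in range(12)))   # thwthwthwtee
```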
@@ -5589,6 +5601,7 @@ static void ggml_compute_forward_rope_f32(
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -5627,7 +5640,7 @@ static void ggml_compute_forward_rope_f32(
             const int64_t p_w = pos[i2 + ne2 * 2];
             const int64_t p_e = pos[i2 + ne2 * 3];
             ggml_mrope_cache_init(
-                p_t, p_h, p_w, p_e, sections, is_vision,
+                p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
         }
@@ -5775,6 +5788,7 @@ static void ggml_compute_forward_rope_f16(
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -5813,7 +5827,7 @@ static void ggml_compute_forward_rope_f16(
             const int64_t p_w = pos[i2 + ne2 * 2];
             const int64_t p_e = pos[i2 + ne2 * 3];
             ggml_mrope_cache_init(
-                p_t, p_h, p_w, p_e, sections, is_vision,
+                p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
         }
@@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
     const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
     const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
+    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -152,6 +152,17 @@ static __global__ void rope_multi(
     const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
+    if (is_imrope) {
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        } else {
+            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        }
+    } else {
     if (sector < sections.v[0]) {
         theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
     }
@@ -164,6 +175,7 @@ static __global__ void rope_multi(
     else if (sector >= sec_w + sections.v[2]) {
         theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
     }
+    }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -276,7 +288,7 @@ template<bool forward, typename T>
 static void rope_multi_cuda(
     const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
     const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
+    const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -287,11 +299,11 @@ static void rope_multi_cuda(
     if (freq_factors == nullptr) {
         rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors, sections);
+            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     } else {
         rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors, sections);
+            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     }
 }
@@ -369,6 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -406,11 +419,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         if (src0->type == GGML_TYPE_F32) {
             rope_multi_cuda<forward>(
                 (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
         } else if (src0->type == GGML_TYPE_F16) {
             rope_multi_cuda<forward>(
                 (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_neox) {
         snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
-    } else if (is_mrope && !is_vision) {
+    } else if ((is_mrope || is_imrope) && !is_vision) {
         GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
         snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
     } else if (is_vision) {
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
         snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
     }
 
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
        return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
 
     return res;
 }
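The Metal change specializes a single kernel source with a function constant instead of compiling a separate kernel, and the cache key (`name`) now encodes the imrope flag so both specializations can coexist. A language-agnostic sketch of that caching pattern (the `compile_pipeline` stub here is hypothetical, standing in for the Metal compile call):

```python
_pipelines: dict[str, object] = {}

def compile_pipeline(base: str, constants: dict) -> object:
    # stand-in for the backend compile; real Metal binds the function
    # constants into the specialized pipeline here
    return (base, tuple(sorted(constants.items())))

def get_pipeline(base: str, is_imrope: bool):
    name = f"{base}_imrope={int(is_imrope)}"  # cache key encodes the specialization
    if name in _pipelines:
        return _pipelines[name]
    pipeline = compile_pipeline(base, constants={"FC_rope_is_imrope": is_imrope})
    _pipelines[name] = pipeline
    return pipeline

# both variants live in the cache under distinct names
assert get_pipeline("kernel_rope_multi_f32", False) is not get_pipeline("kernel_rope_multi_f32", True)
```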
@@ -76,6 +76,7 @@
 #define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
 #define FC_MUL_MV 600
 #define FC_MUL_MM 700
+#define FC_ROPE 800
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPTG 8
@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
 template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
 #endif
 
+constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
+
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -3889,15 +3891,27 @@ kernel void kernel_rope_multi(
         const int sector = ic % sect_dims;
 
         float theta_base;
+        if (FC_rope_is_imrope) {
+            if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
+                theta_base = (float) pos[i2 + args.ne02 * 1];
+            } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
+                theta_base = (float) pos[i2 + args.ne02 * 2];
+            } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
+                theta_base = (float) pos[i2 + args.ne02 * 0];
+            } else { // e
+                theta_base = (float) pos[i2 + args.ne02 * 3];
+            }
+        } else {
         if (sector < args.sect_0) {
             theta_base = (float) pos[i2];
         } else if (sector < sec_w01) {
-            theta_base = (float) pos[i2 + args.ne02];
+            theta_base = (float) pos[i2 + args.ne02 * 1];
         } else if (sector < sec_w012) {
             theta_base = (float) pos[i2 + args.ne02 * 2];
         } else {
             theta_base = (float) pos[i2 + args.ne02 * 3];
         }
+        }
         // end of mrope
 
         const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
-                       const sycl::nd_item<3> & item_ct1) {
+                       const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
     // get index pos
     const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
     if (i0 >= ne0) {
@@ -143,6 +143,17 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
 
     float theta_base = 0.0;
+    if (is_imrope) {
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
+            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
+            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+        } else {
+            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+        }
+    } else {
     if (sector < sections.v[0]) {
         theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
     }
@@ -155,6 +166,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
     else if (sector >= sec_w + sections.v[2]) {
         theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
     }
+    }
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
     float cos_theta;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
                             const float freq_scale, const float freq_base, const float ext_factor,
                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
-                            const mrope_sections sections, queue_ptr stream) {
+                            const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
     if (freq_factors == nullptr) {
         stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
+                                 corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
         });
     } else {
         stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                corr_dims, theta_scale, freq_factors, sections, item_ct1);
+                                corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
         });
     }
 }
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
         if (dst->src[0]->type == GGML_TYPE_F16) {
             rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
                             s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                            freq_factors, sections, main_stream);
+                            freq_factors, sections, is_imrope, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F32) {
             rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
-                            main_stream);
+                            is_imrope, main_stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
@@ -1056,6 +1056,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
    uint32_t s2;
    int32_t sections[4];
+    uint32_t is_imrope;
    uint32_t is_back;
    uint32_t set_rows_stride;
 };
@@ -9927,6 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
     }
 
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
||||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
||||||
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
||||||
{ sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
|
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
|
||||||
}, dryrun);
|
}, dryrun);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -27,6 +27,7 @@ layout (push_constant) uniform parameter {
     uint s1;
     uint s2;
     int sections[4];
+    uint is_imrope;
     uint is_back;
     uint set_rows_stride;
 } p;
@@ -32,6 +32,17 @@ void main() {
     const uint sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
+    if (p.is_imrope != 0) {
+        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
+            theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
+            theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
+            theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+        } else {
+            theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+        }
+    } else {
     if (sector < p.sections[0]) {
         theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
     }
@@ -44,6 +55,7 @@ void main() {
     else if (sector >= sec_w + p.sections[2]) {
         theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
     }
+    }
 
     const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 
     let is_neox = bool(params.mode & 2);
     let is_mrope = bool(params.mode & 8);
+    let is_imrope = params.mode == 40;
     let is_vision = params.mode == 24;
 
     var i = gid.x * 2; // start index for this thread
@@ -248,6 +249,17 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
         let sec_w = params.sections1 + params.sections0;
         let sec_e = params.sections2 + sec_w;
         let sector = (i0 / 2) % sect_dims;
+        if (is_imrope) {
+            if (sector % 3 == 1 && sector < 3 * params.sections1) {
+                theta_base_mult = 1;
+            } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+                theta_base_mult = 2;
+            } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+                theta_base_mult = 0;
+            } else {
+                theta_base_mult = 3;
+            }
+        } else {
         if (sector >= params.sections0 && sector < sec_w) {
             theta_base_mult = 1;
             if (is_vision) {
@@ -268,6 +280,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
                 theta_scale_pwr = sector;
             }
         }
+        }
         let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
         let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
@@ -111,6 +111,7 @@ class Keys:
         EXPERTS_PER_GROUP = "{arch}.experts_per_group"
         MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
         NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
+        NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
         POOLING_TYPE = "{arch}.pooling_type"
         LOGIT_SCALE = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -277,6 +278,7 @@ class Keys:
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
         N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
+        IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
 
         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"
@@ -350,6 +352,8 @@ class MODEL_ARCH(IntEnum):
     QWEN2VL = auto()
     QWEN3 = auto()
     QWEN3MOE = auto()
+    QWEN3VL = auto()
+    QWEN3VLMOE = auto()
     PHI2 = auto()
     PHI3 = auto()
     PHIMOE = auto()
@@ -431,6 +435,7 @@ class VISION_PROJECTOR_TYPE(IntEnum):
     GLM_EDGE = auto()
     MERGER = auto()
    GEMMA3 = auto()
+    QWEN3VL = auto()
     COGVLM = auto()
@@ -648,6 +653,9 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_QUERY = auto() # minicpmv
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
     V_MM_PATCH_MERGER = auto() # mistral small 3.1
+    V_DS_NORM = auto() # qwen3vl
+    V_DS_FC1 = auto() # qwen3vl
+    V_DS_FC2 = auto() # qwen3vl
     V_MM_POST_FC_NORM = auto() # cogvlm
     V_MM_UP = auto() # cogvlm
     V_MM_DOWN = auto() # cogvlm
@@ -709,6 +717,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.QWEN2VL: "qwen2vl",
     MODEL_ARCH.QWEN3: "qwen3",
     MODEL_ARCH.QWEN3MOE: "qwen3moe",
+    MODEL_ARCH.QWEN3VL: "qwen3vl",
+    MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
     MODEL_ARCH.PHI2: "phi2",
     MODEL_ARCH.PHI3: "phi3",
     MODEL_ARCH.PHIMOE: "phimoe",
@@ -1007,6 +1017,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
     MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
     MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
+    MODEL_TENSOR.V_DS_NORM: "v.deepstack.{bid}.norm",
+    MODEL_TENSOR.V_DS_FC1: "v.deepstack.{bid}.fc1",
+    MODEL_TENSOR.V_DS_FC2: "v.deepstack.{bid}.fc2",
     MODEL_TENSOR.V_MM_POST_FC_NORM: "mm.post_fc_norm", # cogvlm
     MODEL_TENSOR.V_MM_UP: "mm.up",
     MODEL_TENSOR.V_MM_DOWN: "mm.down",
@@ -1082,6 +1095,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_QUERY,
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
         MODEL_TENSOR.V_MM_PATCH_MERGER,
+        MODEL_TENSOR.V_DS_NORM,
+        MODEL_TENSOR.V_DS_FC1,
+        MODEL_TENSOR.V_DS_FC2,
         MODEL_TENSOR.V_MM_POST_FC_NORM,
         MODEL_TENSOR.V_MM_UP,
         MODEL_TENSOR.V_MM_DOWN,
@@ -1529,6 +1545,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.QWEN3VL: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.QWEN3VLMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.PLAMO: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -3106,6 +3156,7 @@ class VisionProjectorType:
     LLAMA4 = "llama4"
     QWEN2VL = "qwen2vl_merger"
     QWEN25VL = "qwen2.5vl_merger"
+    QWEN3VL = "qwen3vl_merger"
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
     QWEN2A = "qwen2a" # audio
||||||
|
|
@ -860,6 +860,9 @@ class GGUFWriter:
|
||||||
def add_pooling_type(self, value: PoolingType) -> None:
|
def add_pooling_type(self, value: PoolingType) -> None:
|
||||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||||
|
|
||||||
|
def add_num_deepstack_layers(self, count: int) -> None:
|
||||||
|
self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)
|
||||||
|
|
||||||
def add_rope_dimension_count(self, count: int) -> None:
|
def add_rope_dimension_count(self, count: int) -> None:
|
||||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||||
|
|
||||||
|
|
@@ -1071,6 +1074,9 @@ class GGUFWriter:
     def add_vision_n_wa_pattern(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
 
+    def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
+        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
+
     # audio models
 
     def add_audio_projection_dim(self, value: int) -> None:
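For context, the two new writer methods just emit a uint32 and a bool array under the keys declared in constants.py. A hypothetical usage sketch, assuming a GGUFWriter instance named `writer`; the 24-layer tower with taps at layers 5, 11 and 17 is illustrative, not a real checkpoint's layout:

```python
# Hypothetical: what a conversion run would write for a 24-layer vision
# tower with deepstack taps at layers 5, 11 and 17.
depth = 24
deepstack_visual_indexes = [5, 11, 17]  # illustrative
is_deepstack_layers = [i in deepstack_visual_indexes for i in range(depth)]

writer.add_num_deepstack_layers(len(deepstack_visual_indexes))  # "{arch}.n_deepstack_layers" = 3
writer.add_vision_is_deepstack_layers(is_deepstack_layers)      # "clip.vision.is_deepstack_layers" bool array
```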
@@ -1215,10 +1215,12 @@ class TensorNameMap:
             "model.vision_model.embeddings.position_embedding", # SmolVLM
             "vision_model.positional_embedding_vlm", # llama 4
             "vision_tower.patch_embed.pos_emb", # kimi-vl
+            "visual.pos_embed", # qwen3vl
             "model.vision.patch_embedding.position_embedding", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_QKV: (
+            "visual.blocks.{bid}.attn.qkv", # qwen3vl
             "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
         ),
@@ -1320,6 +1322,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.mlp.fc1", # llama4
             "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
+            "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
             "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
         ),
@@ -1340,6 +1343,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.mlp.fc2", # llama4
             "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
+            "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
             "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
         ),
@@ -1438,6 +1442,18 @@ class TensorNameMap:
             "patch_merger.merging_layer", # mistral
         ),
 
+        MODEL_TENSOR.V_DS_NORM: (
+            "model.visual.deepstack_merger_list.{bid}.norm", # deepstack in qwen3vl
+        ),
+
+        MODEL_TENSOR.V_DS_FC1: (
+            "model.visual.deepstack_merger_list.{bid}.linear_fc1", # deepstack in qwen3vl
+        ),
+
+        MODEL_TENSOR.V_DS_FC2: (
+            "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
+        ),
+
         MODEL_TENSOR.V_MM_POST_FC_NORM: (
             "model.vision.linear_proj.norm1", # cogvlm
         ),
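These entries pair the HF deepstack names with the GGUF templates declared in constants.py, so a name like `model.visual.deepstack_merger_list.1.linear_fc1.weight` resolves to `v.deepstack.{bid}.fc1` with the block index filled in. A simplified sketch of that resolution (not the real TensorNameMap machinery; note that in the converter the list index is additionally remapped to the absolute clip layer index, per the comment in Qwen3VLVisionModel above):

```python
import re

# Simplified pattern -> template pairs mirroring the V_DS_* entries.
TEMPLATES = {
    r"model\.visual\.deepstack_merger_list\.(\d+)\.norm":       "v.deepstack.{bid}.norm",
    r"model\.visual\.deepstack_merger_list\.(\d+)\.linear_fc1": "v.deepstack.{bid}.fc1",
    r"model\.visual\.deepstack_merger_list\.(\d+)\.linear_fc2": "v.deepstack.{bid}.fc2",
}

def map_name(hf_name: str) -> str:
    base, suffix = hf_name.rsplit(".", 1)  # split off .weight / .bias
    for pattern, template in TEMPLATES.items():
        m = re.fullmatch(pattern, base)
        if m:
            return template.format(bid=int(m.group(1))) + "." + suffix
    raise KeyError(hf_name)

assert map_name("model.visual.deepstack_merger_list.1.linear_fc1.weight") == "v.deepstack.1.fc1.weight"
```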
@@ -83,6 +83,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_NORM = 0,
         LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
         LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
+        LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
         LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
     };
||||||
|
|
@ -32,6 +32,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||||
{ LLM_ARCH_QWEN3, "qwen3" },
|
{ LLM_ARCH_QWEN3, "qwen3" },
|
||||||
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
|
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
|
||||||
|
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
|
||||||
|
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
|
||||||
{ LLM_ARCH_PHI2, "phi2" },
|
{ LLM_ARCH_PHI2, "phi2" },
|
||||||
{ LLM_ARCH_PHI3, "phi3" },
|
{ LLM_ARCH_PHI3, "phi3" },
|
||||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||||
|
|
@@ -146,6 +148,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
     { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
     { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
+    { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
     { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
     { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -780,6 +783,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
+    {
+        LLM_ARCH_QWEN3VL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_QWEN3VLMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
    {
        LLM_ARCH_PHI2,
        {

@@ -36,6 +36,8 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3VL,
+    LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,

@@ -150,6 +152,7 @@ enum llm_kv {
     LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
+    LLM_KV_NUM_DEEPSTACK_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,

@@ -148,7 +148,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
 }

 uint32_t llama_hparams::n_pos_per_embd() const {
-    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
 }

 bool llama_hparams::is_swa(uint32_t il) const {

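For readers new to (i)m-rope: both M-RoPE and the interleaved variant (imrope) report 4 position components per embedding, one per rope axis, so anything sized by n_pos_per_embd() quadruples for these models. A minimal sketch of the implied layout for plain text; the helper below is illustrative, not an actual llama.cpp function, and the per-axis ordering is an assumption:

    #include <cstdint>
    #include <vector>

    // illustrative: one block of n_tokens entries per rope axis
    std::vector<int32_t> make_text_positions(int n_tokens, int n_pos_per_embd) {
        std::vector<int32_t> pos((size_t) n_tokens * n_pos_per_embd);
        for (int d = 0; d < n_pos_per_embd; ++d) {  // 4 axes for (i)mrope, 1 otherwise
            for (int i = 0; i < n_tokens; ++i) {
                pos[(size_t) d * n_tokens + i] = i; // text-only: all axes advance together
            }
        }
        return pos;
    }
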
@@ -183,6 +183,9 @@ struct llama_hparams {
     std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
     std::array<float, LLAMA_MAX_LAYERS> xielu_eps;

+    // qwen3vl deepstack
+    uint32_t n_deepstack_layers = 0;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

@@ -1375,7 +1375,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;

     const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
         // @ngxson : this is a workaround
         // for M-RoPE, we want to rotate the whole vector when doing KV shift
         // a normal RoPE should work, we just need to use the correct ordering

@@ -1025,6 +1025,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN3VL:
+            {
+                ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 28: type = LLM_TYPE_1_7B; break;
+                    case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break;
+                    case 64: type = LLM_TYPE_32B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                // since vision model stacks deepstack features along feature dim
+                // we also create a fake "n_embd" for text model to be the main embd + deepstack embds
+                hparams.n_embd *= hparams.n_deepstack_layers + 1;
+            } break;
         case LLM_ARCH_QWEN3MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);

@@ -1036,6 +1051,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN3VLMOE:
+            {
+                ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 48: type = LLM_TYPE_30B_A3B; break;
+                    case 94: type = LLM_TYPE_235B_A22B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                // since vision model stacks deepstack features along feature dim
+                // we also create a fake "n_embd" for text model to be the main embd + deepstack embds
+                hparams.n_embd *= hparams.n_deepstack_layers + 1;
+            } break;
         case LLM_ARCH_PHI2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

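To make the deepstack width bookkeeping concrete: the text hparams advertise an inflated n_embd so that multimodal embedding rows can carry the main features plus one slice per deepstack layer, while the tensor loader divides the same factor back out (see the load_tensors hunks below). A toy check with made-up sizes, not taken from a specific checkpoint:

    #include <cassert>

    int main() {
        const int n_embd_true        = 2048; // hypothetical checkpoint width
        const int n_deepstack_layers = 3;    // hypothetical deepstack count

        const int n_embd_full = n_embd_true * (n_deepstack_layers + 1); // stored in hparams
        assert(n_embd_full / (n_deepstack_layers + 1) == n_embd_true);  // recovered at weight load
        return 0;
    }
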
@@ -3285,7 +3315,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_QWEN3:
+        case LLM_ARCH_QWEN3VL:
             {
+                // for model loading, the weights only have the main embd
+                // so we need to divide by the number of deepstack layers + 1
+                // n_embd is const int so we declare a new variable
+                int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                 // output

@@ -3319,7 +3354,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_QWEN3MOE:
+        case LLM_ARCH_QWEN3VLMOE:
             {
+                // for model loading, the weights only have the main embd
+                // so we need to divide by the number of deepstack layers + 1
+                // n_embd is const int so we declare a new variable
+                int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                 // output

@@ -6428,6 +6468,10 @@ void llama_model::print_info() const {
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
     LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+    // MRoPE (Multi-axis Rotary Position Embedding) sections
+    if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
+        LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]);
+    }
     if (!classifier_labels.empty()) {
         LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);

@@ -6493,7 +6537,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
     }

-    if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) {
+    if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
     }

@@ -9655,6 +9699,301 @@ struct llm_build_qwen3moe : public llm_graph_context {
     }
 };

+struct llm_build_qwen3vl : public llm_graph_context {
+    llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+
+        const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
+        const size_t n_deepstack_layers = hparams.n_deepstack_layers;
+        const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        int sections[4];
+        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+        std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
+
+        if (ubatch.embd) {
+            // Image input: split main embd and deepstack embds
+            ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
+            for (size_t i = 0; i < n_deepstack_layers; i++) {
+                deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
+            }
+            inpL = inpL_main;
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_multi(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_multi(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            if (ubatch.embd && (size_t)il < n_deepstack_layers) {
+                cur = ggml_add(ctx0, cur, deepstack_features[il]);
+                cb(cur, "deepstack_out", il);
+            }
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen3vlmoe : public llm_graph_context {
+    llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
+        const size_t n_deepstack_layers = hparams.n_deepstack_layers;
+        const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        int sections[4];
+        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+        std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
+
+        if (ubatch.embd) {
+            // Image input: split main embd and deepstack embds
+            ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
+            for (size_t i = 0; i < n_deepstack_layers; i++) {
+                deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
+            }
+            inpL = inpL_main;
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_multi(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_multi(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+            cb(moe_out, "ffn_moe_out", il);
+            cur = moe_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            if (ubatch.embd && (size_t)il < n_deepstack_layers) {
+                cur = ggml_add(ctx0, cur, deepstack_features[il]);
+                cb(cur, "deepstack_out", il);
+            }
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_phi2 : public llm_graph_context {
     llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;

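The ggml_view_2d split used by both builders above can be pictured with a plain array: one stacked row of width n_embd * (n_deepstack_layers + 1), where slice 0 is the main embedding and slice i+1 feeds layer i's residual. A self-contained toy with illustrative sizes:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 4, n_ds = 2;               // toy widths
        std::vector<float> row((n_ds + 1) * n_embd);  // [main | ds0 | ds1]
        for (size_t i = 0; i < row.size(); ++i) row[i] = (float) i;

        const float * main_embd = row.data();         // byte offset 0
        for (int i = 0; i < n_ds; ++i) {
            // same offset arithmetic as the (i + 1) * n_embd * sizeof(float) views above
            const float * ds = row.data() + (size_t)(i + 1) * n_embd;
            std::printf("ds%d first value: %g\n", i, ds[0]);
        }
        std::printf("main first value: %g\n", main_embd[0]);
        return 0;
    }
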
@@ -20014,6 +20353,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_qwen3moe>(*this, params);
             } break;
+        case LLM_ARCH_QWEN3VL:
+            {
+                llm = std::make_unique<llm_build_qwen3vl>(*this, params);
+            } break;
+        case LLM_ARCH_QWEN3VLMOE:
+            {
+                llm = std::make_unique<llm_build_qwen3vlmoe>(*this, params);
+            } break;
         case LLM_ARCH_PHI2:
             {
                 llm = std::make_unique<llm_build_phi2>(*this, params);

@@ -20532,6 +20879,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {

         case LLM_ARCH_QWEN2VL:
             return LLAMA_ROPE_TYPE_MROPE;
+        case LLM_ARCH_QWEN3VL:
+        case LLM_ARCH_QWEN3VLMOE:
+            return LLAMA_ROPE_TYPE_IMROPE;

         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:

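Downstream code can branch on the reported rope type without knowing the architecture. A short usage sketch; llama_model_rope_type is public API, while the surrounding branch is illustrative:

    const llama_rope_type rt = llama_model_rope_type(model);
    const bool multi_axis = (rt == LLAMA_ROPE_TYPE_MROPE || rt == LLAMA_ROPE_TYPE_IMROPE);
    // multi_axis implies 4 position components per token (see n_pos_per_embd above)
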
@@ -7076,7 +7076,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
             test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
             test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
+            test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+            test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
+            test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+            test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
             test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+            test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
         }

         test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)

@@ -7092,7 +7097,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

     // single inplace test per type/mode/ff
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
+        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {
             for (bool ff : {false, true}) {
                 test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
             }

@@ -138,7 +138,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
     struct ggml_tensor * x;

     // rope f32
-    for (int m = 0; m < 5; ++m) {
+    for (int m = 0; m < 6; ++m) {
         const int ndims = 4;

         const int64_t n_rot = 128;

@@ -180,7 +180,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
         struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);

         int sections[4] = {16, 24, 24, 0};
-        mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : GGML_ROPE_TYPE_VISION;
+        mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : (m == 4) ? GGML_ROPE_TYPE_VISION : GGML_ROPE_TYPE_IMROPE;

         for (int i = 0; i < ne[2]; ++i) {
             for (int j = 0; j < 4; ++j) {

@@ -39,6 +39,7 @@
 #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
+#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"

 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"

@@ -94,6 +95,9 @@
 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
 #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
 #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
+#define TN_DEEPSTACK_NORM "v.deepstack.%d.norm.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC1 "v.deepstack.%d.fc1.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC2 "v.deepstack.%d.fc2.%s" // qwen3vl deepstack

 // mimicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"

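Expanded per layer, these patterns give GGUF names such as v.deepstack.5.fc1.weight; a quick check with plain snprintf (the loader's string_format expands these patterns the same way):

    #include <cstdio>

    int main() {
        char name[64];
        // same pattern as TN_DEEPSTACK_FC1 above, expanded for layer 5 / "weight"
        std::snprintf(name, sizeof(name), "v.deepstack.%d.fc1.%s", 5, "weight");
        std::printf("%s\n", name); // prints: v.deepstack.5.fc1.weight
        return 0;
    }
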
@@ -136,6 +140,7 @@ enum projector_type {
     PROJECTOR_TYPE_MINICPMV,
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_QWEN2VL,
+    PROJECTOR_TYPE_QWEN3VL,
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,

@@ -161,6 +166,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GLM_EDGE, "adapter"},
     { PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
     { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
+    { PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
     { PROJECTOR_TYPE_GEMMA3, "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL, "pixtral"},

@@ -241,6 +241,18 @@ struct clip_layer {
     // layer scale (no bias)
     ggml_tensor * ls_1_w = nullptr;
     ggml_tensor * ls_2_w = nullptr;

+    // qwen3vl deepstack merger
+    ggml_tensor * deepstack_norm_w = nullptr;
+    ggml_tensor * deepstack_norm_b = nullptr;
+    ggml_tensor * deepstack_fc1_w = nullptr;
+    ggml_tensor * deepstack_fc1_b = nullptr;
+    ggml_tensor * deepstack_fc2_w = nullptr;
+    ggml_tensor * deepstack_fc2_b = nullptr;
+
+    bool has_deepstack() const {
+        return deepstack_fc1_w != nullptr;
+    }
 };

 struct clip_model {

@@ -260,6 +272,8 @@ struct clip_model {

     std::vector<clip_layer> layers;

+    int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer
+
     ggml_tensor * post_ln_w;
     ggml_tensor * post_ln_b;

@@ -840,6 +854,189 @@ struct clip_graph {
         return gf;
     }

+    // Qwen3VL
+    ggml_cgraph * build_qwen3vl() {
+        GGML_ASSERT(model.patch_bias != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
+        const int batch_size = 1;
+        const int n_pos = n_patches;
+        const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
+
+        norm_type norm_t = NORM_TYPE_NORMAL;
+
+        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+        // second conv dimension
+        {
+            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_add(ctx0, inp, inp_1);
+
+            inp = ggml_permute(ctx0, inp, 1, 2, 0, 3);  // [w, h, c, b] -> [c, w, h, b]
+            inp = ggml_cont_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+            inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+            inp = ggml_cont_3d(
+                ctx0, inp,
+                n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        // add patch bias
+        if (model.patch_bias != nullptr) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+
+        // calculate absolute position embedding and apply
+        ggml_tensor * learned_pos_embd = resize_position_embeddings();
+        learned_pos_embd = ggml_cont_4d(
+            ctx0, learned_pos_embd,
+            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        learned_pos_embd = ggml_reshape_4d(
+            ctx0, learned_pos_embd,
+            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
+        learned_pos_embd = ggml_cont_3d(
+            ctx0, learned_pos_embd,
+            n_embd, n_patches_x * n_patches_y, batch_size);
+        inp = ggml_add(ctx0, inp, learned_pos_embd);
+        cb(inp, "inp_pos_emb", -1);
+
+        ggml_tensor * inpL = inp;
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+        }
+
+        // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
+        ggml_tensor * deepstack_features = nullptr;
+        const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "ln1", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = ggml_add(ctx0, cur, layer.qkv_b);
+
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                    cur->nb[1], 0);
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                    cur->nb[1], n_embd * sizeof(float));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                    cur->nb[1], 2 * n_embd * sizeof(float));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // apply M-RoPE
+                Qcur = ggml_rope_multi(
+                    ctx0, Qcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+                Kcur = ggml_rope_multi(
+                    ctx0, Kcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+                cb(Qcur, "Qcur_rope", il);
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            if (layer.has_deepstack()) {
+                ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
+                feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
+                feat = build_ffn(feat,
+                    layer.deepstack_fc1_w, layer.deepstack_fc1_b,
+                    nullptr, nullptr,
+                    layer.deepstack_fc2_w, layer.deepstack_fc2_b,
+                    ffn_op_type::FFN_GELU, il);
+
+                if (!deepstack_features) {
+                    deepstack_features = feat;
+                } else {
+                    // concat along the feature dimension
+                    deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
+                }
+            }
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+        }
+
+        // multimodal projection
+        ggml_tensor * embeddings = inpL;
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+
+        embeddings = build_ffn(embeddings,
+            model.mm_0_w, model.mm_0_b,
+            nullptr, nullptr,
+            model.mm_1_w, model.mm_1_b,
+            ffn_op_type::FFN_GELU, -1);
+
+        embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
     ggml_cgraph * build_minicpmv() {
         const int batch_size = 1;

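To sanity-check the merger arithmetic in build_qwen3vl: with the default 2x2 spatial merge (merge_factor = 4), the projector consumes positions four at a time and each position carries four rope components. A toy calculation with made-up patch counts:

    #include <cstdio>

    int main() {
        const int n_patches_x = 32, n_patches_y = 32; // hypothetical image grid
        const int merge_factor = 4;                   // default 2x2 merge for qwen3vl

        const int n_pos = n_patches_x * n_patches_y;  // 1024 vision positions
        std::printf("output tokens : %d\n", n_pos / merge_factor); // 256 rows reach the LLM
        std::printf("position ids  : %d\n", n_pos * 4);            // 4 rope axes per position
        return 0;
    }
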
@@ -2211,6 +2408,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         {
             res = graph.build_qwen2vl();
         } break;
+    case PROJECTOR_TYPE_QWEN3VL:
+        {
+            res = graph.build_qwen3vl();
+        } break;
     case PROJECTOR_TYPE_MINICPMV:
         {
             res = graph.build_minicpmv();

@@ -2534,6 +2735,12 @@ struct clip_model_loader {
                     hparams.warmup_image_size = hparams.patch_size * 8;
                     get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
                 } break;
+            case PROJECTOR_TYPE_QWEN3VL:
+                {
+                    hparams.image_size = 1024; // still need this?
+                    hparams.warmup_image_size = hparams.patch_size * 8;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
                     hparams.rope_theta = 10000.0f;

@@ -2572,6 +2779,9 @@ struct clip_model_loader {
             LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
             LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
             LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+            if (hparams.spatial_merge_size > 0) {
+                LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
+            }
         } else if (is_audio) {
             LOG_INF("\n--- audio hparams ---\n");
             LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);

@@ -2671,6 +2881,18 @@ struct clip_model_loader {
                 layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
                 layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);

+                // qwen3vl deepstack layer
+                layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false);
+                layer.deepstack_norm_b = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "bias"), false);
+                layer.deepstack_fc1_w = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "weight"), false);
+                layer.deepstack_fc1_b = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "bias"), false);
+                layer.deepstack_fc2_w = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "weight"), false);
+                layer.deepstack_fc2_b = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "bias"), false);
+                if (layer.has_deepstack()) {
+                    model.n_deepstack_layers++;
+                }
+
                 // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
                 // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
                 bool is_ffn_swapped = (

@@ -2806,6 +3028,13 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
+            case PROJECTOR_TYPE_QWEN3VL:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
                     model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);

@@ -3689,7 +3918,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->grid_y = inst.grid_size.height;
         return true;

-    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
         clip_image_u8 resized;
         auto patch_size = params.patch_size * 2;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);

@@ -3915,7 +4144,7 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
 int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
     const int n_total = clip_n_output_tokens(ctx, img);
-    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
         return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
     }
     return n_total;

@@ -3923,7 +4152,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *

 int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
-    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
         return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
     }
     return 1;

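As a worked example of the token-grid formulas above (patch size and image dimensions here are illustrative, not taken from a specific checkpoint):

    #include <cstdio>

    int main() {
        const int patch_size = 16;      // assumed patch size, as in earlier qwen-vl models
        const int nx = 1024, ny = 768;  // toy image dimensions

        // same expressions as clip_n_output_tokens_x / _y: 2x2-merged patches
        const int tx = nx / (patch_size * 2) + (int)(nx % patch_size > 0); // 32
        const int ty = ny / (patch_size * 2) + (int)(ny % patch_size > 0); // 24
        std::printf("token grid: %d x %d = %d image tokens\n", tx, ty, tx * ty);
        return 0;
    }
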
@@ -3979,6 +4208,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
             {
                 // dynamic size (2 conv, so double patch size)
                 int patch_size = params.patch_size * 2;

@@ -4292,6 +4522,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_f32("pos_embed", pos_embed);
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN3VL:
             {
                 const int merge_ratio = 2;
                 const int pw = image_size_width / patch_size;

@@ -4540,6 +4771,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
             return ctx->model.mm_1_b->ne[0];
+        case PROJECTOR_TYPE_QWEN3VL:
+            // main path + deepstack paths
+            return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers);
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:

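This is the vision-side half of the width contract: clip_n_mmproj_embd must match the text model's inflated n_embd, or image embedding rows would misalign when injected. A toy check with hypothetical sizes:

    #include <cassert>

    int main() {
        const int proj_out_width = 2048; // hypothetical value of mm_1_b->ne[0]
        const int n_deepstack    = 3;    // hypothetical deepstack count

        const int vision_row = proj_out_width * (1 + n_deepstack); // clip side
        const int text_row   = proj_out_width * (n_deepstack + 1); // fake n_embd on the text side
        assert(vision_row == text_row);  // both sides must agree
        return 0;
    }
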
@@ -4576,7 +4810,8 @@ bool clip_is_glm(const struct clip_ctx * ctx) {

 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
 }

 bool clip_is_llava(const struct clip_ctx * ctx) {

@@ -267,7 +267,7 @@ struct mtmd_context {
             // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
             img_end = "[IMG_END]";

-        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
+        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
             // <|vision_start|> ... (image embeddings) ... <|vision_end|>
             img_beg = "<|vision_start|>";
             img_end = "<|vision_end|>";

@ -84,6 +84,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then
|
||||||
add_test_vision "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
|
add_test_vision "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
|
||||||
add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
|
add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
|
||||||
add_test_vision "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
|
add_test_vision "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
|
||||||
|
add_test_vision "ggml-org/Qwen3-VL-2B-Instruct-GGUF:Q8_0"
|
||||||
add_test_vision "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
|
add_test_vision "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
|
||||||
add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
|
add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
|
||||||
add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
|
add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
|
||||||
|
|
|
||||||