mtmd : support Kimi VL model (#15458)
* convert : fix tensor naming conflict for llama 4 vision * convert ok * support kimi vision model * clean up * fix style * fix calc number of output tokens * refactor resize_position_embeddings * add test case * rename build fn * correct a small bug
This commit is contained in:
parent
85cc1ae998
commit
79a546220c
|
|
@ -6254,9 +6254,11 @@ class DeepseekModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekV2ForCausalLM")
|
||||
@ModelBase.register("DeepseekV3ForCausalLM")
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"KimiVLForConditionalGeneration",
|
||||
)
|
||||
class DeepseekV2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||
|
||||
|
|
@ -8507,6 +8509,43 @@ class PixtralModel(LlavaVisionModel):
|
|||
return "mm.2.weight"
|
||||
return super().map_tensor_name(name, try_suffixes)
|
||||
|
||||
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
class KimiVLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.hparams_vision["image_size"] = 64 * 14 # for compatibility
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_projector_scale_factor(2)
|
||||
# eps is the same as pytorch's default value
|
||||
assert self.hparams_vision is not None
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
|
||||
|
||||
if is_vision_tensor:
|
||||
if "pos_emb.weight" in name:
|
||||
data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
|
||||
elif "wqkv" in name:
|
||||
split_dim = 0 if "weight" in name else -1
|
||||
wq, wk, wv = data_torch.chunk(3, dim=split_dim)
|
||||
return [
|
||||
(self.map_tensor_name(name.replace("wqkv", "wq")), wq),
|
||||
(self.map_tensor_name(name.replace("wqkv", "wk")), wk),
|
||||
(self.map_tensor_name(name.replace("wqkv", "wv")), wv)
|
||||
]
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
return [] # skip other tensors
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2850,6 +2850,7 @@ class VisionProjectorType:
|
|||
QWEN25O = "qwen2.5o" # omni
|
||||
VOXTRAL = "voxtral"
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
|
|
|||
|
|
@ -1122,6 +1122,7 @@ class TensorNameMap:
|
|||
"vision_encoder.patch_conv", # pixtral
|
||||
"vision_model.patch_embedding.linear", # llama 4
|
||||
"visual.patch_embed.proj", # qwen2vl
|
||||
"vision_tower.patch_embed.proj", # kimi-vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: (
|
||||
|
|
@ -1130,6 +1131,7 @@ class TensorNameMap:
|
|||
"vpm.embeddings.position_embedding",
|
||||
"model.vision_model.embeddings.position_embedding", # SmolVLM
|
||||
"vision_model.positional_embedding_vlm", # llama 4
|
||||
"vision_tower.patch_embed.pos_emb", # kimi-vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
|
|
@ -1141,6 +1143,7 @@ class TensorNameMap:
|
|||
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
|
||||
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
|
||||
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
|
||||
|
|
@ -1157,6 +1160,7 @@ class TensorNameMap:
|
|||
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
|
||||
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
|
||||
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
|
||||
|
|
@ -1173,6 +1177,7 @@ class TensorNameMap:
|
|||
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
|
||||
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
|
||||
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: (
|
||||
|
|
@ -1185,6 +1190,7 @@ class TensorNameMap:
|
|||
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
|
||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"visual.blocks.{bid}.norm1", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: (
|
||||
|
|
@ -1197,6 +1203,7 @@ class TensorNameMap:
|
|||
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
|
||||
"visual.blocks.{bid}.attn.proj", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
|
||||
|
|
@ -1209,6 +1216,7 @@ class TensorNameMap:
|
|||
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
|
||||
"visual.blocks.{bid}.norm2", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
|
|
@ -1221,6 +1229,7 @@ class TensorNameMap:
|
|||
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
||||
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
||||
|
|
@ -1239,6 +1248,7 @@ class TensorNameMap:
|
|||
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
||||
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: (
|
||||
|
|
@ -1263,6 +1273,7 @@ class TensorNameMap:
|
|||
"model.vision_model.post_layernorm", # SmolVLM
|
||||
"vision_model.layernorm_post", # llama4
|
||||
"visual.merger.ln_q", # qwen2vl
|
||||
"vision_tower.encoder.final_layernorm", # kimi-vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_INP_PROJ: (
|
||||
|
|
@ -1272,6 +1283,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.V_MM_INP_NORM: (
|
||||
"multi_modal_projector.norm",
|
||||
"multi_modal_projector.layer_norm",
|
||||
"multi_modal_projector.pre_norm",
|
||||
"pre_mm_projector_norm",
|
||||
),
|
||||
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
PROJECTOR_TYPE_LFM2,
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
{ PROJECTOR_TYPE_LFM2, "lfm2"},
|
||||
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
|
|
|||
|
|
@ -526,57 +526,16 @@ struct clip_graph {
|
|||
cur);
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// pixel_shuffle
|
||||
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
|
||||
|
||||
const int scale_factor = model.hparams.proj_scale_factor;
|
||||
const int n_embd = cur->ne[0];
|
||||
const int seq = cur->ne[1];
|
||||
const int bsz = 1; // batch size, always 1 for now since we don't support batching
|
||||
const int height = std::sqrt(seq);
|
||||
const int width = std::sqrt(seq);
|
||||
GGML_ASSERT(scale_factor != 0);
|
||||
cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_cont_4d(ctx0, cur,
|
||||
n_embd * scale_factor * scale_factor,
|
||||
height / scale_factor,
|
||||
width / scale_factor,
|
||||
bsz);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_cont_3d(ctx0, cur,
|
||||
n_embd * scale_factor * scale_factor,
|
||||
seq / (scale_factor * scale_factor),
|
||||
bsz);
|
||||
|
||||
cur = build_patch_merge_permute(cur, scale_factor);
|
||||
cur = ggml_mul_mat(ctx0, model.projection, cur);
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
|
||||
// pixel unshuffle block
|
||||
const int scale_factor = model.hparams.proj_scale_factor;
|
||||
GGML_ASSERT(scale_factor > 1);
|
||||
|
||||
const int n_embd = cur->ne[0];
|
||||
int width = img.nx / patch_size;
|
||||
int height = img.ny / patch_size;
|
||||
|
||||
// pad width and height to factor
|
||||
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
|
||||
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
|
||||
if (pad_width || pad_height) {
|
||||
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
|
||||
width += pad_width;
|
||||
height += pad_height;
|
||||
}
|
||||
|
||||
// unshuffle h
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
|
||||
// unshuffle w
|
||||
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
|
||||
cur = build_patch_merge_permute(cur, scale_factor);
|
||||
|
||||
// projection
|
||||
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
|
||||
|
|
@ -1086,7 +1045,7 @@ struct clip_graph {
|
|||
n_patches_x / scale_factor,
|
||||
n_patches_y / scale_factor,
|
||||
bsz);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
//cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
// flatten to 2D
|
||||
cur = ggml_cont_2d(ctx0, cur,
|
||||
n_embd * scale_factor * scale_factor,
|
||||
|
|
@ -1113,6 +1072,67 @@ struct clip_graph {
|
|||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph * build_kimivl() {
|
||||
// 2D input positions
|
||||
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
|
||||
ggml_set_name(pos_h, "pos_h");
|
||||
ggml_set_input(pos_h);
|
||||
|
||||
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
|
||||
ggml_set_name(pos_w, "pos_w");
|
||||
ggml_set_input(pos_w);
|
||||
|
||||
ggml_tensor * learned_pos_embd = resize_position_embeddings();
|
||||
|
||||
// build ViT with 2D position embeddings
|
||||
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
|
||||
// first half is X axis and second half is Y axis
|
||||
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
|
||||
};
|
||||
|
||||
ggml_tensor * inp = build_inp();
|
||||
ggml_tensor * cur = build_vit(
|
||||
inp, n_patches,
|
||||
NORM_TYPE_NORMAL,
|
||||
hparams.ffn_op,
|
||||
learned_pos_embd,
|
||||
add_pos);
|
||||
|
||||
cb(cur, "vit_out", -1);
|
||||
|
||||
{
|
||||
// patch_merger
|
||||
const int scale_factor = model.hparams.proj_scale_factor;
|
||||
cur = build_patch_merge_permute(cur, scale_factor);
|
||||
|
||||
// projection norm
|
||||
int proj_inp_dim = cur->ne[0];
|
||||
cur = ggml_view_2d(ctx0, cur,
|
||||
n_embd, cur->ne[1] * scale_factor * scale_factor,
|
||||
ggml_row_size(cur->type, n_embd), 0);
|
||||
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
|
||||
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
|
||||
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
|
||||
cur = ggml_view_2d(ctx0, cur,
|
||||
proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
|
||||
ggml_row_size(cur->type, proj_inp_dim), 0);
|
||||
cb(cur, "proj_inp_normed", -1);
|
||||
|
||||
// projection mlp
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_1_b);
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_2_b);
|
||||
cb(cur, "proj_out", -1);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
// this graph is used by llava, granite and glm
|
||||
// due to having embedding_stack (used by granite), we cannot reuse build_vit
|
||||
ggml_cgraph * build_llava() {
|
||||
|
|
@ -1611,18 +1631,20 @@ private:
|
|||
ggml_tensor * pos_embd = model.position_embeddings;
|
||||
const int height = img.ny / patch_size;
|
||||
const int width = img.nx / patch_size;
|
||||
const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
|
||||
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
|
||||
|
||||
if (!pos_embd || height * width == pos_embd->ne[1]) {
|
||||
GGML_ASSERT(pos_embd);
|
||||
|
||||
if (height == n_per_side && width == n_per_side) {
|
||||
return pos_embd;
|
||||
}
|
||||
|
||||
const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
|
||||
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
|
||||
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_pos_embd, n_pos_embd, n_embd)
|
||||
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1); // -> (width, height, n_embd)
|
||||
pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd); // -> (height * width, n_embd)
|
||||
pos_embd = ggml_transpose(ctx0, pos_embd); // -> (n_embd, height * width)
|
||||
pos_embd = ggml_cont(ctx0, pos_embd);
|
||||
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side)
|
||||
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd)
|
||||
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
|
||||
pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height)
|
||||
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height)
|
||||
|
||||
return pos_embd;
|
||||
}
|
||||
|
|
@ -2021,6 +2043,39 @@ private:
|
|||
return cur;
|
||||
}
|
||||
|
||||
// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
|
||||
// support dynamic resolution
|
||||
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
|
||||
GGML_ASSERT(scale_factor > 1);
|
||||
|
||||
const int n_embd = cur->ne[0];
|
||||
int width = img.nx / patch_size;
|
||||
int height = img.ny / patch_size;
|
||||
|
||||
// pad width and height to factor
|
||||
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
|
||||
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
|
||||
if (pad_width || pad_height) {
|
||||
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
|
||||
width += pad_width;
|
||||
height += pad_height;
|
||||
}
|
||||
|
||||
// unshuffle h
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
|
||||
// unshuffle w
|
||||
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
|
||||
cb(cur, "pixel_shuffle", -1);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
||||
|
|
@ -2063,6 +2118,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
{
|
||||
res = graph.build_whisper_enc();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
res = graph.build_kimivl();
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
res = graph.build_llava();
|
||||
|
|
@ -2313,6 +2372,12 @@ struct clip_model_loader {
|
|||
hparams.image_size = 1024;
|
||||
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
hparams.warmup_image_size = hparams.patch_size * 8;
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
{
|
||||
// default value (used by all model sizes in gemma 3 family)
|
||||
|
|
@ -2477,7 +2542,20 @@ struct clip_model_loader {
|
|||
|
||||
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
|
||||
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
|
||||
if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
|
||||
bool is_ffn_swapped = (
|
||||
// only old models need this fix
|
||||
model.proj_type == PROJECTOR_TYPE_MLP
|
||||
|| model.proj_type == PROJECTOR_TYPE_MLP_NORM
|
||||
|| model.proj_type == PROJECTOR_TYPE_LDP
|
||||
|| model.proj_type == PROJECTOR_TYPE_LDPV2
|
||||
|| model.proj_type == PROJECTOR_TYPE_QWEN2VL
|
||||
|| model.proj_type == PROJECTOR_TYPE_QWEN25VL
|
||||
|| model.proj_type == PROJECTOR_TYPE_GLM_EDGE
|
||||
|| model.proj_type == PROJECTOR_TYPE_GEMMA3
|
||||
|| model.proj_type == PROJECTOR_TYPE_IDEFICS3
|
||||
|| model.proj_type == PROJECTOR_TYPE_MINICPMV
|
||||
) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
|
||||
if (is_ffn_swapped) {
|
||||
// swap up and down weights
|
||||
ggml_tensor * tmp = layer.ff_up_w;
|
||||
layer.ff_up_w = layer.ff_down_w;
|
||||
|
|
@ -2486,6 +2564,9 @@ struct clip_model_loader {
|
|||
tmp = layer.ff_up_b;
|
||||
layer.ff_up_b = layer.ff_down_b;
|
||||
layer.ff_down_b = tmp;
|
||||
if (il == 0) {
|
||||
LOG_WRN("%s: ffn up/down are swapped\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2604,6 +2685,7 @@ struct clip_model_loader {
|
|||
model.projection = get_tensor(TN_MM_PROJECTOR);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
|
||||
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
|
||||
|
|
@ -3507,7 +3589,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
|||
res_imgs->grid_y = inst.grid_size.height;
|
||||
return true;
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
|
||||
} else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
|
||||
) {
|
||||
GGML_ASSERT(params.proj_scale_factor);
|
||||
|
||||
// smart resize
|
||||
|
|
@ -3708,12 +3792,21 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
{
|
||||
// both W and H are divided by proj_scale_factor
|
||||
// both X and Y are downscaled by the scale factor
|
||||
int scale_factor = ctx->model.hparams.proj_scale_factor;
|
||||
n_patches /= (scale_factor * scale_factor);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
// dynamic size
|
||||
int scale_factor = ctx->model.hparams.proj_scale_factor;
|
||||
int out_patch_size = params.patch_size * scale_factor;
|
||||
int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
|
||||
int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
|
||||
n_patches = x_patch * y_patch;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
{
|
||||
// dynamic size
|
||||
|
|
@ -4096,6 +4189,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
set_input_i32("positions", positions);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
// set the 2D positions
|
||||
int n_patches_per_col = image_size_width / patch_size;
|
||||
|
|
@ -4250,6 +4344,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
case PROJECTOR_TYPE_QWEN2A:
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
default:
|
||||
GGML_ABORT("Unknown projector type");
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then
|
|||
add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
|
||||
add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
|
||||
# add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
|
||||
add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M"
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
|
||||
|
|
|
|||
Loading…
Reference in New Issue