Add image support for Kimi-K2.5

Aes Sedai 2026-02-01 02:14:26 -08:00
parent a4c9a08270
commit 9c44981c01
8 changed files with 299 additions and 1 deletion

View File

@ -10688,7 +10688,7 @@ class LightOnOCRVisionModel(LlavaVisionModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiVLForConditionalGeneration", "KimiK25ForConditionalGeneration")
@ModelBase.register("KimiVLForConditionalGeneration")
class KimiVLModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -10729,6 +10729,75 @@ class KimiVLModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiK25ForConditionalGeneration")
class KimiK25Model(MmprojModel):
"""Kimi-K2.5 with MoonViT3d vision encoder"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
self.patch_size = self.hparams_vision.get("patch_size", 14)
# Set image_size for compatibility with base class
# Use position embedding dimensions as image_size reference
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
def set_gguf_parameters(self):
# Base class MmprojModel.set_gguf_parameters() already writes:
# - vision_block_count, vision_head_count, vision_embedding_length
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
# Position embedding parameters (for interpolation) - KimiK25-specific
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
# Projector parameters
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision and projector tensors
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
if not is_vision:
return
# Split fused QKV tensors in vision encoder
if "wqkv" in name:
split_dim = 0 if "weight" in name else -1
wq, wk, wv = data_torch.chunk(3, dim=split_dim)
yield from super().modify_tensors(wq, name.replace("wqkv", "wq"), bid)
yield from super().modify_tensors(wk, name.replace("wqkv", "wk"), bid)
yield from super().modify_tensors(wv, name.replace("wqkv", "wv"), bid)
return
# Temporal embeddings: (T, 1, C) → (T, C)
if "pos_emb.time_weight" in name:
T, _, C = data_torch.shape
data_torch = data_torch.reshape(T, C)
# PatchMergerMLP tensor name mapping
# proj.0.weight → proj.linear_1.weight
# proj.2.weight → proj.linear_2.weight
if "mm_projector.proj.0." in name:
name = name.replace(".proj.0.", ".proj.linear_1.")
elif "mm_projector.proj.2." in name:
name = name.replace(".proj.2.", ".proj.linear_2.")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):
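As a reference outside the diff, a minimal Python sketch of what KimiK25Model does above: the metadata derived in __init__/set_gguf_parameters and the conversion-time tensor rewrites (fused wqkv split, PatchMergerMLP renames). The 1152-dim shape, the tensor names, and the helper functions are illustrative, not part of the conversion script.

import torch

# metadata derived from the vision config (defaults shown)
patch_size, init_pos_emb_height, merge_kernel_size = 14, 64, (2, 2)
image_size = init_pos_emb_height * patch_size         # 896, used as the base-class image_size
proj_scale_factor = merge_kernel_size[0]               # 2, written as the projector scale factor

def split_fused_qkv(name: str, t: torch.Tensor):
    # fused wqkv is [3 * n_embd, n_in] for weights (or [3 * n_embd] for biases);
    # chunking along the output dimension recovers separate q/k/v tensors
    split_dim = 0 if "weight" in name else -1
    wq, wk, wv = t.chunk(3, dim=split_dim)
    return [(name.replace("wqkv", s), p) for s, p in zip(("wq", "wk", "wv"), (wq, wk, wv))]

def rename_projector(name: str) -> str:
    # proj.0 / proj.2 (presumably a Sequential with the activation at index 1) -> linear_1 / linear_2
    return name.replace(".proj.0.", ".proj.linear_1.").replace(".proj.2.", ".proj.linear_2.")

w = torch.randn(3 * 1152, 1152)
for new_name, part in split_fused_qkv("vision_tower.blocks.0.wqkv.weight", w):
    print(new_name, tuple(part.shape))
print(rename_projector("mm_projector.proj.0.weight"))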

View File

@ -3610,6 +3610,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
KIMIK25 = "kimik25"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"

View File

@ -19,6 +19,7 @@ add_library(mtmd
models/glm4v.cpp
models/internvl.cpp
models/kimivl.cpp
models/kimik25.cpp
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp

View File

@ -107,6 +107,17 @@ struct clip_graph {
const bool interleave_freq
);
// 2D RoPE with interleaved frequency
// Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...]
// (build_rope_2d, by contrast, uses the split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...])
ggml_tensor * build_rope_2d_interleaved(
ggml_context * ctx0,
ggml_tensor * cur, // [n_dim, n_head, n_pos]
ggml_tensor * pos_w, // [n_pos] - X/width positions
ggml_tensor * pos_h, // [n_pos] - Y/height positions
const float freq_base
);
// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
// support dynamic resolution
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor);

View File

@ -233,6 +233,7 @@ enum projector_type {
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_UNKNOWN,
};
@ -266,6 +267,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View File

@ -710,6 +710,83 @@ ggml_tensor * clip_graph::build_rope_2d(
return cur;
}
// 2D RoPE with interleaved frequency
// Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...]
// (build_rope_2d, by contrast, uses the split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...])
ggml_tensor * clip_graph::build_rope_2d_interleaved(
ggml_context * ctx0,
ggml_tensor * cur, // [n_dim, n_head, n_pos]
ggml_tensor * pos_w, // [n_pos] - X/width positions
ggml_tensor * pos_h, // [n_pos] - Y/height positions
const float freq_base
) {
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
const int64_t n_pos = cur->ne[2];
GGML_ASSERT(n_dim % 4 == 0); // Must be divisible by 4 for interleaved x,y pairs
// Step 1: Reshape to expose interleaved structure
// cur: [n_dim, n_head, n_pos] -> [4, n_dim/4, n_head, n_pos]
ggml_tensor * reshaped = ggml_reshape_4d(ctx0, cur, 4, n_dim/4, n_head, n_pos);
// Step 2: Extract X pairs (elements 0,1 of each group of 4)
// x_pairs: [2, n_dim/4, n_head, n_pos]
ggml_tensor * x_pairs = ggml_view_4d(ctx0, reshaped,
2, n_dim/4, n_head, n_pos,
reshaped->nb[1], reshaped->nb[2], reshaped->nb[3],
0);
// Step 3: Extract Y pairs (elements 2,3 of each group of 4)
// y_pairs: [2, n_dim/4, n_head, n_pos]
ggml_tensor * y_pairs = ggml_view_4d(ctx0, reshaped,
2, n_dim/4, n_head, n_pos,
reshaped->nb[1], reshaped->nb[2], reshaped->nb[3],
2 * ggml_element_size(reshaped));
// Step 4: Make contiguous and reshape for rope_ext
// [2, n_dim/4, n_head, n_pos] -> [n_dim/2, n_head, n_pos]
x_pairs = ggml_cont(ctx0, x_pairs);
x_pairs = ggml_reshape_3d(ctx0, x_pairs, n_dim/2, n_head, n_pos);
y_pairs = ggml_cont(ctx0, y_pairs);
y_pairs = ggml_reshape_3d(ctx0, y_pairs, n_dim/2, n_head, n_pos);
// Step 5: Apply RoPE to X pairs using pos_w, Y pairs using pos_h
x_pairs = ggml_rope_ext(
ctx0,
x_pairs,
pos_w,
nullptr,
n_dim/2,
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);
y_pairs = ggml_rope_ext(
ctx0,
y_pairs,
pos_h,
nullptr,
n_dim/2,
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);
// Step 6: Reshape back to [2, n_dim/4, n_head, n_pos] for interleaving
x_pairs = ggml_reshape_4d(ctx0, x_pairs, 2, n_dim/4, n_head, n_pos);
y_pairs = ggml_reshape_4d(ctx0, y_pairs, 2, n_dim/4, n_head, n_pos);
// Step 7: Interleave X and Y pairs back together
// Concatenate along dimension 0: [4, n_dim/4, n_head, n_pos]
ggml_tensor * result = ggml_concat(ctx0, x_pairs, y_pairs, 0);
// Step 8: Reshape back to original: [n_dim, n_head, n_pos]
result = ggml_reshape_3d(ctx0, result, n_dim, n_head, n_pos);
return result;
}
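For intuition, a rough numpy sketch of the interleaved layout handled above: the head dimension is grouped in fours as [x_re, x_im, y_re, y_im], X pairs are rotated by angles from pos_w, Y pairs by pos_h, and the rotated pairs are re-interleaved. It is not a bit-exact reproduction of ggml_rope_ext; the per-pair frequency formula is the standard RoPE one and is assumed here.

import numpy as np

def rope_pairs(x, pos, freq_base):
    # x: [n_pos, n_pairs, 2] - rotate each (re, im) pair by a position-dependent angle
    n_pairs = x.shape[1]
    inv_freq = freq_base ** (-np.arange(n_pairs) / n_pairs)   # one frequency per pair
    ang = pos[:, None] * inv_freq[None, :]                    # [n_pos, n_pairs]
    cos, sin = np.cos(ang), np.sin(ang)
    re, im = x[..., 0], x[..., 1]
    return np.stack([re * cos - im * sin, re * sin + im * cos], axis=-1)

def rope_2d_interleaved(x, pos_w, pos_h, freq_base=10000.0):
    # x: [n_pos, n_dim], n_dim % 4 == 0, laid out as [x_re, x_im, y_re, y_im, ...]
    n_pos, n_dim = x.shape
    g = x.reshape(n_pos, n_dim // 4, 4)
    x_rot = rope_pairs(g[:, :, 0:2], pos_w, freq_base)        # X pairs use width positions
    y_rot = rope_pairs(g[:, :, 2:4], pos_h, freq_base)        # Y pairs use height positions
    return np.concatenate([x_rot, y_rot], axis=-1).reshape(n_pos, n_dim)

x = np.random.randn(6, 8)                                     # 6 patches, head dim 8
pos_w = np.array([0, 1, 2, 0, 1, 2], dtype=float)
pos_h = np.array([0, 0, 0, 1, 1, 1], dtype=float)
print(rope_2d_interleaved(x, pos_w, pos_h).shape)             # (6, 8)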
// Generic function to stack frames for audio processing
// Abstracts out the StackAudioFrames logic used by ultravox
ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) {
@ -825,6 +902,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_kimivl>(ctx, img);
} break;
case PROJECTOR_TYPE_KIMIK25:
{
builder = std::make_unique<clip_graph_kimik25>(ctx, img);
} break;
case PROJECTOR_TYPE_COGVLM:
{
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
@ -1139,6 +1220,13 @@ struct clip_model_loader {
hparams.set_limit_image_tokens(8, 1024);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_KIMIK25:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
hparams.set_limit_image_tokens(8, 4096);
hparams.set_warmup_n_tokens(256);
} break;
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
@ -1668,6 +1756,7 @@ struct clip_model_loader {
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@ -3039,6 +3128,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(res));
} break;
case PROJECTOR_TYPE_KIMIK25:
{
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
original_size,
params.patch_size * params.n_merge,
params.image_min_pixels,
params.image_max_pixels);
const std::array<uint8_t, 3> pad_color = {0, 0, 0};
clip_image_u8 resized_img;
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
} break;
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
case PROJECTOR_TYPE_LDP:
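A rough Python approximation of the preprocessing target-size computation for this projector: scale to fit a pixel budget while preserving aspect ratio, then align both sides to patch_size * n_merge (28 px for patch 14, merge 2). The exact rounding of img_tool::calc_size_preserved_ratio may differ, and mapping the 8..4096 image-token limit from the loader to a pixel budget via (patch_size * n_merge)^2 is an assumption.

import math

def calc_target_size(w, h, align, min_pixels, max_pixels):
    # scale the area into [min_pixels, max_pixels], then snap each side to a multiple of align
    scale = 1.0
    if w * h > max_pixels:
        scale = math.sqrt(max_pixels / (w * h))
    elif w * h < min_pixels:
        scale = math.sqrt(min_pixels / (w * h))
    tw = max(align, round(w * scale / align) * align)
    th = max(align, round(h * scale / align) * align)
    return tw, th

align = 14 * 2                       # patch_size * n_merge
min_pixels = 8 * align * align       # assumed: min 8 image tokens
max_pixels = 4096 * align * align    # assumed: max 4096 image tokens
print(calc_target_size(4032, 3024, align, min_pixels, max_pixels))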
@ -3247,6 +3353,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
// dynamic size
int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
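The dynamic output-token count then follows directly from the resized dimensions; a short sketch, assuming the preprocessing above has already aligned both sides to out_patch_size:

def n_output_tokens(img_w, img_h, patch_size=14, n_merge=2):
    out_patch_size = patch_size * n_merge      # 28x28 pixels per output token
    return (img_w // out_patch_size) * (img_h // out_patch_size)

print(n_output_tokens(896, 672))               # 32 * 24 = 768 image tokens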
@ -3588,6 +3695,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
case PROJECTOR_TYPE_LIGHTONOCR:
{
// set the 2D positions
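A small sketch of how the pos_w / pos_h inputs consumed by the interleaved RoPE could be filled, assuming row-major patch order (column index into pos_w, row index into pos_h); KIMIK25 shares this code path with the other 2D-RoPE projectors in the case above.

def patch_positions(n_patches_x, n_patches_y):
    # row-major over the patch grid: pos_w gets the column index, pos_h the row index
    pos_w, pos_h = [], []
    for y in range(n_patches_y):
        for x in range(n_patches_x):
            pos_w.append(x)
            pos_h.append(y)
    return pos_w, pos_h

pw, ph = patch_positions(4, 2)
print(pw)   # [0, 1, 2, 3, 0, 1, 2, 3]
print(ph)   # [0, 0, 0, 0, 1, 1, 1, 1]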
@ -3770,6 +3878,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];

View File

@ -0,0 +1,98 @@
#include "models.h"
#include <cstring>
#include <cmath>
// note: this is similar to clip_graph::resize_position_embeddings; the main difference is that the
// stored embedding keeps width/height explicitly in ne[1] and ne[2] instead of deriving them via sqrt.
// It could alternatively be stored in 2D with a fused w*h dimension. The permute is also slightly
// different: (2, 1, 0, 3) instead of (2, 0, 1, 3).
ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = interpolation_mode;
GGML_ASSERT(pos_embd);
const int64_t stored_c = pos_embd->ne[0]; // C = 1152
const int64_t orig_w = pos_embd->ne[1]; // W = 64
const int64_t orig_h = pos_embd->ne[2]; // H = 64
GGML_ASSERT(stored_c == n_embd);
if (height == (int)orig_h && width == (int)orig_w) {
// No interpolation needed, just flatten to [C, H*W]
return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
}
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode);
pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3);
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
return pos_embd;
}
ggml_cgraph * clip_graph_kimik25::build() {
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC);
// Kimi-K2.5 uses the INTERLEAVED frequency pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...]
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
return build_rope_2d_interleaved(ctx0, cur, pos_w, pos_h, hparams.rope_theta);
};
ggml_tensor * inp = build_inp();
// For reasons not yet understood, passing learned_pos_embd into build_vit led to the ggml_add
// not occurring; adding it manually here works.
inp = ggml_add(ctx0, inp, learned_pos_embd);
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
nullptr,
add_pos);
cb(cur, "vit_out", -1);
{
// patch_merger
const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor);
// projection norm
int proj_inp_dim = cur->ne[0];
cur = ggml_view_2d(ctx0, cur,
n_embd, cur->ne[1] * scale_factor * scale_factor,
ggml_row_size(cur->type, n_embd), 0);
cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
cur = ggml_view_2d(ctx0, cur,
proj_inp_dim, cur->ne[1],
ggml_row_size(cur->type, proj_inp_dim), 0);
cb(cur, "proj_inp_normed", -1);
// projection mlp
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU,
-1);
cb(cur, "proj_out", -1);
}
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
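For comparison, a rough PyTorch sketch of the two KimiK25-specific pieces in the graph above: resize_position_embeddings_3d (treating the stored embedding as a [64, 64, 1152] H x W x C grid) and the projection block, where LayerNorm is applied per original patch (over n_embd) before the merged features go through the GELU MLP. ggml's bicubic interpolation, the flattened patch order, and the exact weight shapes may differ in detail; this is illustrative only.

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embd: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # pos_embd: [orig_h, orig_w, n_embd] learned grid, e.g. 64 x 64 x 1152
    orig_h, orig_w, n_embd = pos_embd.shape
    if (height, width) == (orig_h, orig_w):
        return pos_embd.reshape(orig_h * orig_w, n_embd)        # no interpolation needed
    grid = pos_embd.permute(2, 0, 1).unsqueeze(0)                # [1, C, H, W]
    grid = F.interpolate(grid, size=(height, width), mode="bicubic", align_corners=False)
    return grid.squeeze(0).permute(1, 2, 0).reshape(height * width, n_embd)

def patch_merger_mlp(x, norm_w, norm_b, w1, b1, w2, b2, scale=2, eps=1e-5):
    # x: [n_merged_tokens, scale*scale*n_embd] coming out of the patch-merge permute
    n_tokens, d = x.shape
    n_embd = d // (scale * scale)
    # normalize each original patch over n_embd, not over the merged dimension
    h = F.layer_norm(x.reshape(n_tokens * scale * scale, n_embd), (n_embd,), norm_w, norm_b, eps)
    h = h.reshape(n_tokens, d)
    return F.gelu(h @ w1.T + b1) @ w2.T + b2                     # 2-layer GELU projection

pe = torch.randn(64, 64, 1152)
print(resize_pos_embed(pe, 96, 64).shape)   # 96x64 patch grid -> torch.Size([6144, 1152])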

View File

@ -109,3 +109,10 @@ struct clip_graph_mobilenetv5 : clip_graph {
ggml_tensor * inp,
const mobilenetv5_block & block);
};
struct clip_graph_kimik25 : clip_graph {
clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode);
};