mtmd: fix use_non_causal being reported incorrectly (#18793)

* mtmd: fix use_non_causal being reported incorrectly

* move clip_is_mrope to mtmd_decode_use_mrope

* fix sloppy code ggml_cpy
This commit is contained in:
Xuan-Son Nguyen 2026-01-13 12:19:38 +01:00 committed by GitHub
parent 0a57271ab6
commit e047f9ee9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 26 deletions

View File

@ -258,12 +258,12 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
res->add_input(std::move(inp));
} else {
// Vision embedding path: use padding token (ID=0) embedding
// TODO: verify if this is the correct behavior in transformers implementation
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
// Extract and dequantize padding token embedding (column 0)
ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size);
inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32);
// Extract and dequantize padding token embedding (row 0)
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
// Reshape to [n_embd_altup, n_layer, 1]
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);

View File

@ -3808,18 +3808,6 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
}
bool clip_is_mrope(const struct clip_ctx * ctx) {
switch (ctx->proj_type()) {
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
return true;
default:
return false;
}
}
bool clip_is_llava(const struct clip_ctx * ctx) {
return ctx->model.hparams.has_llava_projector;
}

View File

@ -104,7 +104,6 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
int clip_is_minicpmv(const struct clip_ctx * ctx);
bool clip_is_glm(const struct clip_ctx * ctx);
bool clip_is_mrope(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx);
// note for contributor: this clip_is_(model) pattern is deprecated
// do NOT add new functions like this

View File

@ -146,8 +146,6 @@ struct mtmd_context {
bool tok_row_end_trail = false;
bool ov_img_first = false;
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
// string template for slice image delimiters with row/col (idefics3)
std::string sli_img_start_tmpl;
@ -217,7 +215,6 @@ struct mtmd_context {
void init_vision() {
GGML_ASSERT(ctx_v != nullptr);
use_mrope = clip_is_mrope(ctx_v);
projector_type proj = clip_get_projector_type(ctx_v);
int minicpmv_version = clip_is_minicpmv(ctx_v);
@ -627,7 +624,7 @@ struct mtmd_tokenizer {
}
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
if (ctx->use_mrope) {
if (mtmd_decode_use_mrope(ctx)) {
// for Qwen2VL, we need this information for M-RoPE decoding positions
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
@ -863,10 +860,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
switch (ctx->proj_type_v()) {
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_YOUTUVL:
case PROJECTOR_TYPE_GEMMA3:
return true;
default:
return false;
@ -874,7 +868,15 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
}
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
return ctx->use_mrope;
switch (ctx->proj_type_v()) {
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
return true;
default:
return false;
}
}
bool mtmd_support_vision(mtmd_context * ctx) {