mtmd: use causal attn for gemma 4 audio (#21824)

This commit is contained in:
Xuan-Son Nguyen 2026-04-13 09:47:55 +02:00 committed by GitHub
parent 974c8c94cc
commit 920b3e78cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 5 deletions

View File

@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
batch_embd.set_position_normal(n_past, seq_id);
}
if (mtmd_decode_use_non_causal(ctx)) {
const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
if (use_non_causal) {
llama_set_causal_attn(lctx, false);
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
}
@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
n_past += mtmd_input_chunk_get_n_pos(chunk);
*new_n_past = n_past;
if (mtmd_decode_use_non_causal(ctx)) {
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return 0;

View File

@ -1017,8 +1017,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
return ctx->image_embd_v.data();
}
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
switch (ctx->proj_type_v()) {
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
auto proj_type = ctx->proj_type_v();
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
proj_type = ctx->proj_type_a();
}
switch (proj_type) {
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_GEMMA4V:
return true;

View File

@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
MTMD_API void mtmd_free(mtmd_context * ctx);
// whether we need to set non-causal mask before llama_decode
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
// if chunk is nullptr, we assume the default case where chunk is an image chunk
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
// whether the current model use M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);