diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp
index 778aacb61d..2f45dab447 100644
--- a/tools/mtmd/mtmd-helper.cpp
+++ b/tools/mtmd/mtmd-helper.cpp
@@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
         batch_embd.set_position_normal(n_past, seq_id);
     }
 
-    if (mtmd_decode_use_non_causal(ctx)) {
+    const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
+    if (use_non_causal) {
         llama_set_causal_attn(lctx, false);
         // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
     }
@@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
     n_past += mtmd_input_chunk_get_n_pos(chunk);
     *new_n_past = n_past;
 
-    if (mtmd_decode_use_non_causal(ctx)) {
+    if (use_non_causal) {
         llama_set_causal_attn(lctx, true);
     }
     return 0;
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 0b27f960be..dc2bde1944 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -1017,8 +1017,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
-    switch (ctx->proj_type_v()) {
+bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
+    auto proj_type = ctx->proj_type_v();
+    if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+        proj_type = ctx->proj_type_a();
+    }
+    switch (proj_type) {
     case PROJECTOR_TYPE_GEMMA3:
     case PROJECTOR_TYPE_GEMMA4V:
         return true;
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index ebb4a18fb3..2ecf95694d 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
 MTMD_API void mtmd_free(mtmd_context * ctx);
 
 // whether we need to set non-causal mask before llama_decode
-MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
+// if chunk is nullptr, we assume the default case where chunk is an image chunk
+MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
 
 // whether the current model use M-RoPE for llama_decode
 MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);