mtmd: use causal attn for gemma 4 audio (#21824)
This commit is contained in:
parent
974c8c94cc
commit
920b3e78cb
|
|
@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
|
|||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
|
|
@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
|
|||
n_past += mtmd_input_chunk_get_n_pos(chunk);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -1017,8 +1017,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
|||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
switch (ctx->proj_type_v()) {
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
auto proj_type = ctx->proj_type_v();
|
||||
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
proj_type = ctx->proj_type_a();
|
||||
}
|
||||
switch (proj_type) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
|||
MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
|
||||
// whether we need to set non-causal mask before llama_decode
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||
// if chunk is nullptr, we assume the default case where chunk is an image chunk
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
|
||||
|
||||
// whether the current model use M-RoPE for llama_decode
|
||||
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||
|
|
|
|||
Loading…
Reference in New Issue