CUDA: skip compilation of superfluous FA kernels (#21768)

This commit is contained in:
Johannes Gäßler 2026-04-11 18:52:11 +02:00 committed by GitHub
parent 073bb2c20b
commit ff5ef82786
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 19 additions and 11 deletions

View File

@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if constexpr (DKQ <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
return;
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
return;
} else {
GGML_ABORT("fatal error");
}
}
if (use_gqa_opt && gqa_ratio > 4) {
@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
return;
}
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if constexpr (DKQ <= 256) {
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
} else {
GGML_ABORT("fatal error");
}
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {