CUDA: tune GLM 4.7 Flash FA kernel selection logic (#19097)
This commit is contained in:
parent
c0204a0893
commit
a5bb8ba4c5
|
|
@ -148,6 +148,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
if (gqa_ratio == 20) { // GLM 4.7 Flash
|
||||
if (cc >= GGML_CUDA_CC_BLACKWELL) {
|
||||
if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
|
|
@ -161,6 +165,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
}
|
||||
if (cc >= GGML_CUDA_CC_TURING) {
|
||||
if (Q->ne[1] <= 4) {
|
||||
if (K->ne[1] <= 16384) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue