This commit is contained in:
CaffeinatedBits 2026-03-16 12:39:21 +11:00 committed by GitHub
commit 9c5d68f926
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 6 deletions

View File

@ -22,6 +22,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
} break;
case 88: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 88, 88>(ctx, dst);
} break;
case 96: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);

View File

@ -44,6 +44,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 64, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 64, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 64, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 64, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 64, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48)
@ -100,6 +106,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
@ -160,6 +172,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 64, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
@ -224,6 +243,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 64, 256, 2, 32, 88)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
@ -1251,6 +1277,7 @@ extern DECL_FATTN_TILE_CASE( 40, 40);
extern DECL_FATTN_TILE_CASE( 64, 64);
extern DECL_FATTN_TILE_CASE( 72, 72);
extern DECL_FATTN_TILE_CASE( 80, 80);
extern DECL_FATTN_TILE_CASE( 88, 88);
extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);

View File

@ -317,6 +317,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
case 64:
case 72:
case 80:
case 88:
case 96:
case 128:
case 112:
@ -368,7 +369,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores are available, use them:
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@ -392,7 +393,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_MMA_F16;
}
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) {
int gqa_ratio_eff = 1;
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
@ -408,14 +409,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
return BEST_FATTN_KERNEL_WMMA_F16;
}
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) {
@ -441,7 +442,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use MFMA flash attention for CDNA (MI100+):
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88 && Q->ne[0] != 256 && Q->ne[0] != 576) {
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
// MMA vs tile crossover benchmarked on MI300X @ d32768:
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(88, 88);

View File

@ -3,7 +3,7 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 88, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]