diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a2..549ea286ef 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -22,6 +22,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); } break; + case 88: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 88, 88>(ctx, dst); + } break; case 96: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f3fa80ab23..1f3f146381 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -44,6 +44,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 64, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 64, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 64, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 64, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 64, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) @@ -100,6 +106,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) @@ -160,6 +172,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 64, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) @@ -224,6 +243,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 64, 256, 2, 32, 88) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) @@ -1251,6 +1277,7 @@ extern DECL_FATTN_TILE_CASE( 40, 40); extern DECL_FATTN_TILE_CASE( 64, 64); extern DECL_FATTN_TILE_CASE( 72, 72); extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 88, 88); extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496..d99f4efde7 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -317,6 +317,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case 64: case 72: case 80: + case 88: case 96: case 128: case 112: @@ -368,7 +369,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; // If Turing tensor cores are available, use them: - if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { @@ -392,7 +393,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } - if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) { int gqa_ratio_eff = 1; const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { @@ -408,14 +409,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } return BEST_FATTN_KERNEL_WMMA_F16; } - if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (Q->ne[1] == 1) { @@ -441,7 +442,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88 && Q->ne[0] != 256 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq88-dv88.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq88-dv88.cu new file mode 100644 index 0000000000..8a8ae539ed --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq88-dv88.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(88, 88); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae2..82b390bb25 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 88, 96, 112, 128, 256, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]