ggml-cuda : add flash attention support for head size 88
Llama 4 vision models use a head dimension of D=88. Previously, this fell back to unoptimized operations, causing massive VRAM bloat and slow inference. This adds D=88 to the CUDA flash attention tile backend. It explicitly excludes 88 from the Turing/Volta/WMMA/MMA Tensor Core checks to prevent memory misalignment/segfaults, forcing the fallback to the TILE kernel. Also updates generate_cu_files.py to dynamically generate the required template instance.
This commit is contained in:
parent
23fbfcb1ad
commit
41e6c8caf4
|
|
@ -22,6 +22,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
} break;
|
||||
case 88: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 88, 88>(ctx, dst);
|
||||
} break;
|
||||
case 96: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
|
||||
|
|
|
|||
|
|
@ -44,6 +44,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 64, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 64, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 64, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 64, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 64, 88)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48)
|
||||
|
|
@ -100,6 +106,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
|
|
@ -160,6 +172,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 64, 256, 2, 32, 88)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
|
|
@ -224,6 +243,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 2, 64, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 4, 128, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 8, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 16, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 32, 256, 2, 32, 88)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 88, 88, 64, 256, 2, 32, 88)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
|
|
@ -1251,6 +1277,7 @@ extern DECL_FATTN_TILE_CASE( 40, 40);
|
|||
extern DECL_FATTN_TILE_CASE( 64, 64);
|
||||
extern DECL_FATTN_TILE_CASE( 72, 72);
|
||||
extern DECL_FATTN_TILE_CASE( 80, 80);
|
||||
extern DECL_FATTN_TILE_CASE( 88, 88);
|
||||
extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
|
|
|
|||
|
|
@ -317,6 +317,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
case 64:
|
||||
case 72:
|
||||
case 80:
|
||||
case 88:
|
||||
case 96:
|
||||
case 128:
|
||||
case 112:
|
||||
|
|
@ -368,7 +369,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores are available, use them:
|
||||
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
|
|
@ -392,7 +393,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) {
|
||||
int gqa_ratio_eff = 1;
|
||||
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
|
||||
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
||||
|
|
@ -408,14 +409,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
}
|
||||
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
|
|
@ -441,7 +442,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
}
|
||||
|
||||
// Use MFMA flash attention for CDNA (MI100+):
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 88 && Q->ne[0] != 256 && Q->ne[0] != 576) {
|
||||
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
|
||||
// MMA vs tile crossover benchmarked on MI300X @ d32768:
|
||||
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(88, 88);
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 88, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue