From 88458164c77509d2022e45f71aaf97040667abe2 Mon Sep 17 00:00:00 2001
From: Anav Prasad
Date: Wed, 1 Apr 2026 07:07:24 +0000
Subject: [PATCH] CUDA: Add Flash Attention Support for Head Dimension 512 (#20998)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* flash attention support for head dimension 512 added

* FA D=512 - match 576 configs, limit ncols2, revert vec cap

* fix HIP tile kernel build for D=512

* fix HIP tile kernel occupancy for D=512 on AMD

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler

* fix tile FA compilation

---------

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/fattn-mma-f16.cuh          | 30 ++++++++++++++-
 ggml/src/ggml-cuda/fattn-tile.cu              |  4 ++
 ggml/src/ggml-cuda/fattn-tile.cuh             | 37 +++++++++++++++----
 ggml/src/ggml-cuda/fattn.cu                   | 11 ++++--
 ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu |  1 +
 ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu |  1 +
 ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu |  1 +
 ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu |  1 +
 ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu |  1 +
 ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu |  1 +
 ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu |  1 +
 ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu |  1 +
 .../fattn-tile-instance-dkq512-dv512.cu       |  5 +++
 .../template-instances/generate_cu_files.py   |  4 +-
 14 files changed, 86 insertions(+), 13 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index fff70c8eb8..b613ae61fb 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
 
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
 
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
 }
 
 static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
 
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16(
 #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
+    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
         NO_DEVICE_CODE;
         return;
     }
@@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
 DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
 DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
 
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
+
 // The number of viable configurations for Deepseek is very limited:
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu
index 3fcb09b7a2..25b16e83ca 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cu
+++ b/ggml/src/ggml-cuda/fattn-tile.cu
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
             GGML_ASSERT(V->ne[0] == K->ne[0]);
             ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
         } break;
+        case 512: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
+        } break;
         case 576: {
             GGML_ASSERT(V->ne[0] == 512);
             ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index f3fa80ab23..26721cc4c7 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
@@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
@@ -767,7 +785,7 @@ static __global__ void flash_attn_tile(
 #ifdef GGML_USE_WMMA_FATTN
             (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
 #endif // GGML_USE_WMMA_FATTN
-            (use_logit_softcap && !(DV == 128 || DV == 256))
+            (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512))
     ) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta,
             scale, max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
     const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
     const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
-    if constexpr (DV == 512) {
+    if constexpr (DKQ == 576) {
         if (use_gqa_opt && gqa_ratio % 16 == 0) {
             launch_fattn_tile_switch_ncols1<DKQ, DV, 16>(ctx, dst);
             return;
@@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
         }
     }
 
-    if constexpr (DV <= 256) {
+    if constexpr (DKQ <= 512) {
         if (use_gqa_opt && gqa_ratio % 8 == 0) {
             launch_fattn_tile_switch_ncols1<DKQ, DV, 8>(ctx, dst);
             return;
@@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
             return;
         }
 
-        if (use_gqa_opt && gqa_ratio % 2 == 0) {
-            launch_fattn_tile_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+        if constexpr (DV <= 256) {
+            if (use_gqa_opt && gqa_ratio % 2 == 0) {
+                launch_fattn_tile_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+                return;
+            }
+
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 1>(ctx, dst);
             return;
         }
-
-        launch_fattn_tile_switch_ncols1<DKQ, DV, 1>(ctx, dst);
-        return;
     }
     GGML_ABORT("fatal error");
 }
@@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
 extern DECL_FATTN_TILE_CASE(112, 112);
 extern DECL_FATTN_TILE_CASE(128, 128);
 extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(512, 512);
 extern DECL_FATTN_TILE_CASE(576, 512);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index a25a890db6..a21c536104 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
             GGML_ASSERT(V->ne[0] == 256);
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
             break;
+        case 512:
+            GGML_ASSERT(V->ne[0] == 512);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
+            break;
         case 576: {
             // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
             GGML_ASSERT(V->ne[0] == 512);
@@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         case 128:
         case 112:
         case 256:
-            if (V->ne[0] != K->ne[0]) {
+        case 512:
+            if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
+    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
@@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // Use MFMA flash attention for CDNA (MI100+):
-    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
+    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
         const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
         // MMA vs tile crossover benchmarked on MI300X @ d32768:
         // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
index dc16829021..22d383173f 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index 517993cb06..d2415bfa95 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index 97b19c67ad..8eec1d74e2 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
index 163b1d939e..84b674cd05 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index 989626dfa5..3475dfea08 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
index bad296b414..5906398db9 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 173de7aac7..684cd25ce0 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
index 680a13ca6d..4bc60d62f9 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu
new file mode 100644
index 0000000000..7c61d8d2ec
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(512, 512);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 3b5ab12fc4..b7b5832293 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -3,7 +3,7 @@
 from glob import glob
 import os
 
-HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
 
 TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
 
@@ -83,6 +83,8 @@ for ncols in [8, 16, 32, 64]:
             continue
         if head_size_kq == 72:
             continue
+        if head_size_kq == 512 and ncols2 not in (4, 8):
+            continue
         if head_size_kq != 576 and ncols2 in (16, 32):
             continue
         if head_size_kq == 576 and ncols2 not in (4, 16, 32):