From ecbcb7ea9d3303097519723b264a8b5f1e977028 Mon Sep 17 00:00:00 2001 From: Jayant Lohia Date: Sat, 28 Feb 2026 00:07:26 +0530 Subject: [PATCH] CUDA: add CDNA3 MFMA support for flash attention MMA kernel (#19806) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: add CDNA3 MFMA support for flash attention MMA kernel Add MI300X (gfx942) MFMA tensor core flash attention using v_mfma_f32_16x16x16_f16 (FP16 in, FP32 accumulate). - Add FATTN_WARP_SIZE=64 for CDNA wavefront64 - Add CDNA config for head sizes 64, 80, 96, 112, 128 - Add FP16 MFMA intrinsic path in mma.cuh - Add manual V transpose load for MFMA register layout - Route CDNA to MMA for prompt processing, VEC for token generation - Fix Q loading and combine stride granularity for non-power-of-2 heads Benchmarks (Qwen2.5-1.5B Q4_K_M, MI300X): pp512 +7%, pp1024 +13%, pp2048 +23%, pp4096 +39% tg128 -10% (FA overhead, VEC used for both) All 2480 flash attention tests pass. Ref: https://github.com/ggml-org/llama.cpp/issues/17917 * address review: replace FATTN_WARP_SIZE with constexpr, improve dispatch - Replace #define FATTN_WARP_SIZE with constexpr int warp_size = ggml_cuda_get_physical_warp_size() in each device function - Use ne[1]*gqa_ratio threshold for MMA vs tile dispatch. Benchmarked crossover on MI300X @ d32768 with power-of-2 GQA models: hsk=64 (Llama 1B, gqa=4): MMA wins at eff >= 128 (+11%) hsk=128 (Llama 3B, gqa=4): MMA wins at eff >= 128 (+4%) Unified threshold: eff_nq >= 128 for all head sizes. - Remove VEC fallback; small batches fall through to tile kernel * Update ggml/src/ggml-cuda/fattn.cu * use ggml_cuda_info().devices warp_size instead of hardcoded check --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 246 ++++++++++++++++++--------- ggml/src/ggml-cuda/fattn.cu | 12 ++ ggml/src/ggml-cuda/mma.cuh | 30 +++- 3 files changed, 203 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 0b8ef90794..beb7e32e4f 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { + // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). 
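+    // Only head sizes 64/80/96/112/128 are dispatched at runtime (see the DKQ guard in
+    // flash_attn_ext_f16 and the dispatch in fattn.cu); the DKQ=256 rows below, like the
+    // trailing fallback, exist only so the compile-time static_asserts see valid values.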
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + + // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy + // compile-time static_asserts even though the kernel guard prevents runtime execution. + // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. 
+ return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); +} + static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if (ampere_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); @@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c if (turing_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + if (amd_mfma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); + } if (amd_wmma_available(cc)) { return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); } @@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); #elif defined(TURING_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(AMD_MFMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); #elif defined(VOLTA_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); #elif defined(AMD_WMMA_AVAILABLE) @@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, } static constexpr __device__ int get_cols_per_thread() { -#if defined(AMD_WMMA_AVAILABLE) - return 1; // RDNA has a single column. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + return 1; // AMD has a single column per thread. #else return 2; // This is specifically KQ columns, Volta only has a single VKQ column. -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } static __host__ int get_cols_per_warp(const int cc) { - if (turing_mma_available(cc) || amd_wmma_available(cc)) { + if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) { return 16; } else { // Volta @@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. if constexpr (use_cp_async) { @@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); auto load = [&] __device__ (auto n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 
0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); } @@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } else { // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); const int k0_stop = D2 - D2 % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } @@ -324,18 +368,19 @@ template= 32 ? nbatch_fa * sizeof(half) : 64; - constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int cols_per_warp = 8*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { @@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? 
mask_h[j_vram*stride_mask + i] : half(0.0f); } } - } else if constexpr (nbatch_fa < 2*WARP_SIZE) { - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + } else if constexpr (nbatch_fa < 2*warp_size) { + constexpr int cols_per_warp = 2*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + const int i = threadIdx.x % (warp_size/cols_per_warp); ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); } @@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); @@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #else // Volta T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; @@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { // Wide version of KQ_C is column-major -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); #else // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); } else { // Wide version of KQ_C is column-major -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); #else // swap A and B for CUDA. 
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = l % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); } } @@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = l % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; } else { @@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = (l/2) % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } @@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Values per KQ column are spread across 4 threads: constexpr int offset_first = 2; constexpr int offset_last = 1; +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16). 
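+            // With the 16x16 MFMA accumulator layout, each KQ column's 4 partial
+            // values sit in lanes col, col+16, col+32, col+48, so two xor-shuffles
+            // (offset 32, then 16) complete the per-column max reduction.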
+ constexpr int offset_first = 32; + constexpr int offset_last = 16; #elif defined(AMD_WMMA_AVAILABLE) // Values per KQ column are spread across 2 threads: constexpr int offset_first = 16; @@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = (l/2) % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; } else { @@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const half2 KQ_max_scale_h2 = make_half2( KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll @@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { @@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. #if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. + // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. + // Load with transposed addressing: 4 strided half loads. + { + const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; + const half * xs0_h = (const half *) xs0; + const int stride_h = stride_tile_V * 2; // stride in half units + half * A_h = (half *) A.x; +#pragma unroll + for (int l = 0; l < 4; ++l) { + A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; + } + } #else // TODO: Try to transpose tile_V when loading gmem to smem. // Use mma to transpose T_A_VKQ for RDNA. T_A_VKQ A_trans; load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); mma(A, A_trans, A_identity); -#endif // defined(TURING_MMA_AVAILABLE) +#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. 
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); } } -#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. @@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) @@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major @@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; using T_A_KQ = typename mma_tile_sizes::T_A_KQ; using T_B_KQ = typename mma_tile_sizes::T_B_KQ; @@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; @@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 
0 : DKQ/2 - (DKQ/2) % (2*stride_k); const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { - const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { break; @@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); @@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } @@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The partial sums are spread across 8/4 threads. constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_MFMA_AVAILABLE) + // The partial sums are spread across 4 threads (wavefront64, 16 cols). + constexpr int offset_first = 32; + constexpr int offset_last = 16; #elif defined(AMD_WMMA_AVAILABLE) // The partial sums are spread across 2 threads. 
constexpr int offset_first = 16; @@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size); } } } @@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { @@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; @@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1; - const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. 
@@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + if (offset < warp_size) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size)); } } @@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + if (offset < warp_size) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size); } } @@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } // Combined KQ max + rowsum. - static_assert(cols_per_warp <= WARP_SIZE); - if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + static_assert(cols_per_warp <= warp_size); + if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } @@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 
0 : threadIdx.x / stride_k); if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { break; @@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll @@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) } template @@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16( } #endif // defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_MFMA_AVAILABLE) + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); - constexpr int nwarps = nthreads / WARP_SIZE; + constexpr int nwarps = nthreads / warp_size; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
@@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) } template @@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); - const int nwarps = nthreads / WARP_SIZE; + const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size; + const int nwarps = nthreads / warp_size_host; constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu @@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 721edd9994..85c177f496 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + // Use MFMA flash attention for CDNA (MI100+): + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); + // MMA vs tile crossover benchmarked on MI300X @ d32768: + // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) + // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) + if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { + return BEST_FATTN_KERNEL_MMA_F16; + } + // Fall through to tile kernel for small effective batch sizes. + } + // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index dd45d6c78f..5d1dadd3e4 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -668,7 +668,7 @@ namespace ggml_cuda_mma { return ret; } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -964,6 +964,34 @@ namespace ggml_cuda_mma { GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: FP16 input, FP32 accumulate, convert back to half2. 
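+    // v_mfma_f32_16x16x16_f16 performs a full 16x16x16 matrix multiply-accumulate per
+    // wavefront: each of the 64 lanes supplies 4 f16 values for A and for B and holds
+    // 4 f32 accumulator values. The trailing (0, 0, 0) arguments are cbsz/abid/blgp
+    // (no block broadcast or input swizzle).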
+    using halfx4_t  = __attribute__((ext_vector_type(4))) _Float16;
+    using floatx4_t = __attribute__((ext_vector_type(4))) float;
+
+    // Convert existing half2 accumulator to float for MFMA:
+    floatx4_t acc_f32;
+    {
+        const halfx4_t acc_h = reinterpret_cast<const halfx4_t &>(D.x[0]);
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            acc_f32[i] = (float)acc_h[i];
+        }
+    }
+
+    const halfx4_t & a_frag = reinterpret_cast<const halfx4_t &>(A.x[0]);
+    const halfx4_t & b_frag = reinterpret_cast<const halfx4_t &>(B.x[0]);
+    acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
+
+    // Convert back to half2:
+    {
+        halfx4_t result_h;
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            result_h[i] = (_Float16)acc_f32[i];
+        }
+        reinterpret_cast<halfx4_t &>(D.x[0]) = result_h;
+    }
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
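
For readers who want to sanity-check the register layout the kernel relies on, below is a
minimal standalone HIP probe. It is not part of this patch; the file name, host-side checks,
and single-wavefront launch are illustrative, and it assumes a CDNA target (e.g. compile with
hipcc --offload-arch=gfx942). With A set to the identity, D must reproduce B, which confirms
the A_mat[i=lane%16][k=4*(lane/16)+reg] operand layout and the column-major accumulator layout
(column = lane%16, rows spread across lane groups of 16) described in the comments above.

    // mfma_probe.hip -- layout probe for v_mfma_f32_16x16x16_f16 (illustrative sketch).
    #include <hip/hip_runtime.h>
    #include <cstdio>

    using halfx4_t  = __attribute__((ext_vector_type(4))) _Float16;
    using floatx4_t = __attribute__((ext_vector_type(4))) float;

    // One wavefront computes D = A*B for 16x16 tiles.
    // A is stored row-major (i*16 + k), B column-major (j*16 + k), D row-major (i*16 + j).
    __global__ void mfma_probe(const _Float16 * A, const _Float16 * B, float * D) {
        const int lane = threadIdx.x; // launched with a single wavefront of 64 lanes
        halfx4_t a, b;
    #pragma unroll
        for (int l = 0; l < 4; ++l) {
            a[l] = A[(lane % 16)*16 + 4*(lane / 16) + l]; // A_mat[i = lane%16][k = 4*(lane/16) + l]
            b[l] = B[(lane % 16)*16 + 4*(lane / 16) + l]; // B_mat[k = 4*(lane/16) + l][j = lane%16]
        }
        floatx4_t c = {0.0f, 0.0f, 0.0f, 0.0f};
        c = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0);
    #pragma unroll
        for (int l = 0; l < 4; ++l) {
            D[(4*(lane / 16) + l)*16 + lane % 16] = c[l]; // D_mat[i = 4*(lane/16) + l][j = lane%16]
        }
    }

    int main() {
        _Float16 hA[256], hB[256];
        float    hD[256];
        for (int i = 0; i < 256; ++i) {
            hA[i] = (i/16 == i%16) ? (_Float16) 1.0f : (_Float16) 0.0f; // A = identity
            hB[i] = (_Float16) (float) (i % 97);                        // arbitrary, f16-exact values
        }
        _Float16 *dA, *dB;
        float    *dD;
        hipMalloc(&dA, sizeof(hA));
        hipMalloc(&dB, sizeof(hB));
        hipMalloc(&dD, sizeof(hD));
        hipMemcpy(dA, hA, sizeof(hA), hipMemcpyHostToDevice);
        hipMemcpy(dB, hB, sizeof(hB), hipMemcpyHostToDevice);
        mfma_probe<<<1, 64>>>(dA, dB, dD);
        hipMemcpy(hD, dD, sizeof(hD), hipMemcpyDeviceToHost);
        // With A = I: D[i][j] == B[i][j], i.e. hD[i*16 + j] == hB[j*16 + i].
        int bad = 0;
        for (int i = 0; i < 16; ++i) {
            for (int j = 0; j < 16; ++j) {
                bad += hD[i*16 + j] != (float) hB[j*16 + i];
            }
        }
        printf("%s\n", bad == 0 ? "layout confirmed" : "layout mismatch");
        return 0;
    }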