From 8f91ca54ec0b22f3ff3a495f32be8e8300638cdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 24 Jan 2026 10:09:36 +0100 Subject: [PATCH] CUDA: re-use MLA K data for V in MMA FA (#19057) --- ggml/src/ggml-cuda/fattn-common.cuh | 74 ++++++++++++++-------------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 63 +++++++++++------------ ggml/src/ggml-cuda/fattn.cu | 5 ++ 3 files changed, 72 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index a781fb91f5..40c7725784 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -782,12 +782,7 @@ void launch_fattn( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - // TODO: make this more generic by removing the notion of "MLA". - // for example "is V a view of K?" so we can skip loading it. - // V strides should be driven by V itself and avoid assumption of the data layout - const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K; - - GGML_ASSERT(V || is_mla); + const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data; const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; @@ -797,9 +792,9 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); - GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); - GGML_ASSERT( K->nb[0] == ggml_element_size(K)); - GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -820,10 +815,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = V ? (const char *) V->data : nullptr; - size_t nb21 = V ? V->nb[1] : nb11; - size_t nb22 = V ? V->nb[2] : nb12; - size_t nb23 = V ? 
V->nb[3] : nb13; + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; if (need_f16_K && K->type != GGML_TYPE_F16) { const size_t bs = ggml_blck_size(K->type); @@ -852,32 +847,39 @@ void launch_fattn( K_data = (char *) K_f16.ptr; } - if (V && need_f16_V && V->type != GGML_TYPE_F16) { - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - V_f16.alloc(ggml_nelements(V)); - if (ggml_is_contiguously_allocated(V)) { - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; } else { - GGML_ASSERT(V->nb[0] == ts); - to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); - const int64_t s01 = nb21 / ts; - const int64_t s02 = nb22 / ts; - const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); - nb21 = V->ne[0] * sizeof(half); - nb22 = V->ne[1] * nb21; - nb23 = V->ne[2] * nb22; + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; } - V_data = (char *) V_f16.ptr; } const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 203569e345..3e7d67b40d 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, @@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? 
stride_tile_K : nbatch_V2 + 4;
 
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
@@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
     if constexpr (nstages > 1) {
         static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
-        static_assert(!mla, "multi-stage loading not implemented for MLA");
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
     }
 
+    // For MLA, K and V have the same data.
+    // Therefore, iterate over K in reverse and later re-use the data if possible.
 #pragma unroll
-    for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
         const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
@@ -776,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
     if constexpr (nstages > 1) {
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         // Preload K tile for next iteration:
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -791,11 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
-    // For MLA K and V have the same data.
-    // Therefore, iterate over V in reverse and re-use the data if possible.
-    static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
-    constexpr int reusable_cutoff = DV; // TODO implement properly
 
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity; make_identity_mat(A_identity);
@@ -803,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
-    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
-        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
-        const int i0_diff = i0_stop - i0_start;
+    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+        static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+        const int i0_stop = i0_start + 2*nbatch_V2;
+        const int i0_diff = i0_stop - i0_start;
 
         if constexpr (nstages <= 1) {
-            if (i0_start < reusable_cutoff) {
+            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
                 constexpr bool use_cp_async = nstages == 1;
                 flash_attn_ext_f16_load_tile
                     (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -818,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 __syncthreads();
             }
         }
-        const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
+        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
 
 #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         constexpr int i0_stride = cols_per_warp == 8 ? 
T_C_VKQ::I : 2*T_C_VKQ::J; @@ -921,7 +919,7 @@ template struct mma_tile_sizes { }; #endif // defined(TURING_MMA_AVAILABLE) -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -975,8 +973,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -1080,7 +1077,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1089,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; const int k_VKQ_sup = ne11 - kb0*nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1100,7 +1097,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1109,7 +1106,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1457,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } -template +template __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1509,8 +1506,6 @@ static __global__ void flash_attn_ext_f16( } #endif // defined(AMD_WMMA_AVAILABLE) - static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); @@ -1523,7 +1518,7 @@ static __global__ void flash_attn_ext_f16( const int stride_K = nb11 / sizeof(half2); const int stride_mask = nb31 / sizeof(half); - const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int stride_V = V_is_K_view ? 
stride_K : nb21 / sizeof(half2); const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; @@ -1553,7 +1548,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1564,12 +1559,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } @@ -1597,7 +1592,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1608,7 +1603,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. 
constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else @@ -1644,7 +1639,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); const int nwarps = nthreads / WARP_SIZE; - constexpr bool mla = DKQ == 576; + constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); @@ -1669,7 +1664,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1680,7 +1675,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 87f07a2f93..ba2b96bc32 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -247,6 +247,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } } + const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data; + const int cc = ggml_cuda_info().devices[device].cc; switch (K->ne[0]) { @@ -269,6 +271,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (!gqa_opt_applies || gqa_ratio % 4 != 0) { return BEST_FATTN_KERNEL_NONE; } + if (!V_is_K_view) { + return BEST_FATTN_KERNEL_NONE; + } break; default: return BEST_FATTN_KERNEL_NONE;