From 5c662d21a3a1c6a41d8abe401f5791712a5c02ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Thu, 15 Jan 2026 15:14:50 +0100
Subject: [PATCH] CUDA: fix alignment on register spill for FA (#18815)

---
 ggml/src/ggml-cuda/fattn-common.cuh |  4 +--
 ggml/src/ggml-cuda/fattn-tile.cuh   | 42 ++++++++++++++---------------
 ggml/src/ggml-cuda/fattn-vec.cuh    |  4 +--
 3 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 6b55f784f3..8468ba8488 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -59,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
 
 #pragma unroll
     for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
-        half2 tmp[cpy_ne];
+        __align__(16) half2 tmp[cpy_ne];
         ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
@@ -309,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__
         ggml_cuda_memcpy_1(dst, (const half *) vx + i0);
     } else if constexpr (std::is_same_v) {
         static_assert(ne % 2 == 0, "bad ne");
-        half2 tmp[ne/2];
+        __align__(16) half2 tmp[ne/2];
         ggml_cuda_memcpy_1(tmp, (const half *) vx + i0);
         float2 * dst_f2 = (float2 *) dst;
 #pragma unroll
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe67f..f055da8e2b 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -343,7 +343,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
         for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
             const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
 
-            const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+            const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
             ggml_cuda_memcpy_1(
                 tile_KV + i*(J/2 + J_padding) + j,
                 !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -394,11 +394,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
             const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
 
             const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
-            half2 tmp_h2[cpy_ne/2];
+            __align__(16) half2 tmp_h2[cpy_ne/2];
             ggml_cuda_memcpy_1(
                 tmp_h2,
                 !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
-            float2 tmp_f2[cpy_ne/2];
+            __align__(16) float2 tmp_f2[cpy_ne/2];
 #pragma unroll
             for (int l = 0; l < cpy_ne/2; ++l) {
                 tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -445,14 +445,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
     static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
-        half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        half2 Q_k[cpw][cpy_ne];
+        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) half2 Q_k[cpw][cpy_ne];
 #else
     static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
-        float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        float Q_k[cpw][cpy_ne];
+        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -602,9 +602,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
 #ifdef FAST_FP16_AVAILABLE
-        half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #else
-        float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -664,8 +664,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
     for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-        half2 V_k[(DVp/2)/warp_size];
-        half2 KQ_k[cpw];
+        __align__(16) half2 V_k[(DVp/2)/warp_size];
+        __align__(16) half2 KQ_k[cpw];
 
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
@@ -676,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
             const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
 
-            half tmp[KQ_cs];
+            __align__(16) half tmp[KQ_cs];
             ggml_cuda_memcpy_1(
                 &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
 #pragma unroll
@@ -696,8 +696,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #else
 #pragma unroll
     for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-        float2 V_k[(DVp/2)/warp_size];
-        float KQ_k[cpw];
+        __align__(16) float2 V_k[(DVp/2)/warp_size];
+        __align__(16) float KQ_k[cpw];
 
         constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
@@ -821,12 +821,12 @@ static __global__ void flash_attn_tile(
     __shared__ half2 Q_tmp[ncols * DKQ/2];
     __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
     __shared__ half KQ[ncols * nbatch_fa];
-    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #else
     __shared__ float Q_tmp[ncols * DKQ];
     __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
     __shared__ float KQ[ncols * nbatch_fa];
-    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #endif // FAST_FP16_AVAILABLE
 
     float KQ_max[cpw];
@@ -849,7 +849,7 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
-                float tmp_f[cpy_ne_D] = {0.0f};
+                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                 ggml_cuda_memcpy_1
                     (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                         + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -860,7 +860,7 @@ static __global__ void flash_attn_tile(
                 }
 
 #ifdef FAST_FP16_AVAILABLE
-                half2 tmp_h2[cpy_ne_D/2];
+                __align__(16) half2 tmp_h2[cpy_ne_D/2];
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                     tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -959,7 +959,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            half2 tmp[cpy_ne_D];
+            __align__(16) half2 tmp[cpy_ne_D];
             ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -970,7 +970,7 @@ static __global__ void flash_attn_tile(
        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
-            float tmp[cpy_ne_D];
+            __align__(16) float tmp[cpy_ne_D];
             ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1033,7 +1033,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            float2 tmp[cpy_ne_D];
+            __align__(16) float2 tmp[cpy_ne_D];
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 86f4dc0f7f..3f4a78cc6e 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec(
 #ifdef V_DOT2_F32_F16_AVAILABLE
     half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
-    float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+    __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
 #endif // V_DOT2_F32_F16_AVAILABLE
     int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
@@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec(
         for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
             const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
 
-            float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
+            __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
             if (ncols == 1 || ic0 + j < int(ne01.z)) {
                 ggml_cuda_memcpy_1(tmp, &Q_j[i]);
                 ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
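
Every hunk above follows the same pattern: a thread-local array that is the source or destination of a 16-byte vectorized copy gains an explicit __align__(16). While such an array stays in registers its alignment is irrelevant, but if the compiler spills it to local memory the spill slot only gets the element type's natural alignment (4 or 8 bytes), and a 16-byte vectorized access to it then faults with a misaligned address. Below is a minimal, self-contained sketch of that failure mode, assuming a helper that performs a single 16-byte copy; copy_16B and spill_alignment_example are hypothetical names written for illustration only (copy_16B mimics the spirit of ggml_cuda_memcpy_1, not its actual signature).

// Illustrative sketch only: why __align__(16) matters for arrays that may spill.
#include <cuda_fp16.h>

// Hypothetical stand-in for a 16-byte vectorized copy helper; both pointers
// must be 16-byte aligned for the int4 load/store to be legal.
static __device__ __forceinline__ void copy_16B(void * dst, const void * src) {
    *(int4 *) dst = *(const int4 *) src;
}

// src is assumed to be 16-byte aligned.
__global__ void spill_alignment_example(const half2 * __restrict__ src, float * __restrict__ dst) {
    // Without __align__(16), tmp is only guaranteed alignof(half2) == 4 bytes.
    // If tmp is spilled to local memory at a 4- or 8-byte boundary, the int4
    // access inside copy_16B becomes a misaligned address error. The attribute
    // forces a 16-byte-aligned spill slot, so the copy is valid either way.
    __align__(16) half2 tmp[4];
    copy_16B(tmp, src + 4*threadIdx.x); // 4 half2 = 16 bytes per thread

    float acc = 0.0f;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        const float2 f = __half22float2(tmp[i]);
        acc += f.x + f.y;
    }
    dst[threadIdx.x] = acc;
}

Marking the declarations __align__(16) keeps the vectorized copies valid whether or not the arrays are spilled, which is what the hunks above do for the flash-attention register tiles.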