CUDA: fix alignment on register spill for FA (#18815)

Johannes Gäßler 2026-01-15 15:14:50 +01:00 committed by GitHub
parent 8cc0ba957b
commit 5c662d21a3
3 changed files with 25 additions and 25 deletions
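
For context on why the one-line changes below matter: each patched array is per-thread scratch storage involved in ggml_cuda_memcpy_1-style copies of up to 16 bytes, which the compiler emits as wide (up to 128-bit) loads and stores. While such an array lives in registers its alignment is irrelevant, but if it spills to local memory the spill slot only gets the element type's natural alignment (2 to 8 bytes for the half/half2/float/float2 types used here), and a 16-byte vectorized access to it can then trap with a misaligned-address error. Declaring the arrays __align__(16) pins the spill slot to 16-byte alignment. A minimal sketch of the pattern, assuming a hypothetical copy16() helper in place of the real copy routine:

#include <cuda_fp16.h>

// One 128-bit load plus one 128-bit store; both pointers must be 16-byte aligned,
// otherwise the GPU raises a misaligned-address error.
static __device__ __forceinline__ void copy16(void * dst, const void * src) {
    *reinterpret_cast<int4 *>(dst) = *reinterpret_cast<const int4 *>(src);
}

// Doubles every value of a row, moving 4 half2 (= 16 bytes) per thread per iteration.
// Assumes ncols2 % 4 == 0 and cudaMalloc'ed (hence 16-byte-aligned) src/dst.
static __global__ void scale_rows(const half2 * __restrict__ src, half2 * __restrict__ dst, const int ncols2) {
    const half2 * row_in  = src + blockIdx.x*ncols2;
    half2       * row_out = dst + blockIdx.x*ncols2;
    for (int i = threadIdx.x*4; i < ncols2; i += blockDim.x*4) {
        // A plain "half2 tmp[4]" is only guaranteed 4-byte alignment once it spills to
        // local memory; __align__(16) keeps the vectorized copies legal either way.
        __align__(16) half2 tmp[4];
        copy16(tmp, row_in + i);
#pragma unroll
        for (int k = 0; k < 4; ++k) {
            const float2 f = __half22float2(tmp[k]);
            tmp[k] = make_half2(2.0f*f.x, 2.0f*f.y);
        }
        copy16(row_out + i, tmp);
    }
}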


@@ -59,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
- half2 tmp[cpy_ne];
+ __align__(16) half2 tmp[cpy_ne];
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
@@ -309,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
} else if constexpr (std::is_same_v<T, float>) {
static_assert(ne % 2 == 0, "bad ne");
- half2 tmp[ne/2];
+ __align__(16) half2 tmp[ne/2];
ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
float2 * dst_f2 = (float2 *) dst;
#pragma unroll
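
The required alignment follows from the copy width: a byte copy templated on a compile-time size will naturally dispatch to the widest vector type that divides it, so a 16-byte call becomes a single int4 access. The dispatch sketch below illustrates that general technique under that assumption; it is not the actual ggml_cuda_memcpy_1 source.

// Illustrative size-templated copy: pick the widest naturally-aligned vector type.
// Once nbytes is a multiple of 16, dst and src must both be 16-byte aligned, which is
// what the __align__(16) annotations in this commit guarantee for spilled arrays.
template <int nbytes>
static __device__ __forceinline__ void memcpy_fixed(void * dst, const void * src) {
    if constexpr (nbytes % 16 == 0) {
#pragma unroll
        for (int i = 0; i < nbytes/16; ++i) {
            reinterpret_cast<int4 *>(dst)[i] = reinterpret_cast<const int4 *>(src)[i];
        }
    } else if constexpr (nbytes % 8 == 0) {
#pragma unroll
        for (int i = 0; i < nbytes/8; ++i) {
            reinterpret_cast<int2 *>(dst)[i] = reinterpret_cast<const int2 *>(src)[i];
        }
    } else if constexpr (nbytes % 4 == 0) {
#pragma unroll
        for (int i = 0; i < nbytes/4; ++i) {
            reinterpret_cast<int *>(dst)[i] = reinterpret_cast<const int *>(src)[i];
        }
    } else {
#pragma unroll
        for (int i = 0; i < nbytes; ++i) {
            reinterpret_cast<char *>(dst)[i] = reinterpret_cast<const char *>(src)[i];
        }
    }
}

With such a dispatch, memcpy_fixed<sizeof(tmp)>(tmp, src) on a half2 tmp[4] is one 128-bit transaction, which is the case the annotations above protect on register spill.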


@@ -343,7 +343,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
- const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+ const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
ggml_cuda_memcpy_1<cpy_nb>(
tile_KV + i*(J/2 + J_padding) + j,
!oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -394,11 +394,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
- half2 tmp_h2[cpy_ne/2];
+ __align__(16) half2 tmp_h2[cpy_ne/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
- float2 tmp_f2[cpy_ne/2];
+ __align__(16) float2 tmp_f2[cpy_ne/2];
#pragma unroll
for (int l = 0; l < cpy_ne/2; ++l) {
tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -445,14 +445,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
- half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
- half2 Q_k[cpw][cpy_ne];
+ __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+ __align__(16) half2 Q_k[cpw][cpy_ne];
#else
static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
- float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
- float Q_k[cpw][cpy_ne];
+ __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+ __align__(16) float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
@@ -602,9 +602,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#pragma unroll
for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
#ifdef FAST_FP16_AVAILABLE
- half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+ __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#else
- float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+ __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
@@ -664,8 +664,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k1 = 0; k1 < nbatch_V; k1 += np) {
- half2 V_k[(DVp/2)/warp_size];
- half2 KQ_k[cpw];
+ __align__(16) half2 V_k[(DVp/2)/warp_size];
+ __align__(16) half2 KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
@@ -676,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
- half tmp[KQ_cs];
+ __align__(16) half tmp[KQ_cs];
ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
&tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
#pragma unroll
@@ -696,8 +696,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#else
#pragma unroll
for (int k1 = 0; k1 < nbatch_V; k1 += np) {
- float2 V_k[(DVp/2)/warp_size];
- float KQ_k[cpw];
+ __align__(16) float2 V_k[(DVp/2)/warp_size];
+ __align__(16) float KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
@@ -821,12 +821,12 @@ static __global__ void flash_attn_tile(
__shared__ half2 Q_tmp[ncols * DKQ/2];
__shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
__shared__ half KQ[ncols * nbatch_fa];
- half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+ __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#else
__shared__ float Q_tmp[ncols * DKQ];
__shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
__shared__ float KQ[ncols * nbatch_fa];
- float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+ __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#endif // FAST_FP16_AVAILABLE
float KQ_max[cpw];
@@ -849,7 +849,7 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
- float tmp_f[cpy_ne_D] = {0.0f};
+ __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
ggml_cuda_memcpy_1<sizeof(tmp_f)>
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -860,7 +860,7 @@ static __global__ void flash_attn_tile(
}
#ifdef FAST_FP16_AVAILABLE
- half2 tmp_h2[cpy_ne_D/2];
+ __align__(16) half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -959,7 +959,7 @@ static __global__ void flash_attn_tile(
constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
- half2 tmp[cpy_ne_D];
+ __align__(16) half2 tmp[cpy_ne_D];
ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -970,7 +970,7 @@ static __global__ void flash_attn_tile(
constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
- float tmp[cpy_ne_D];
+ __align__(16) float tmp[cpy_ne_D];
ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1033,7 +1033,7 @@ static __global__ void flash_attn_tile(
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
- float2 tmp[cpy_ne_D];
+ __align__(16) float2 tmp[cpy_ne_D];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);


@@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec(
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
- float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+ __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
@@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec(
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
- float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
+ __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
if (ncols == 1 || ic0 + j < int(ne01.z)) {
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);