CUDA: fix FA VKQ accumulator overflow (#17746)
This commit is contained in:
parent
668ed76574
commit
e95d0bc8fd
|
|
@ -10,6 +10,12 @@
|
||||||
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
||||||
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
||||||
|
|
||||||
|
// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
|
||||||
|
// by the VKQ accumulators is effectively being shifted up by a factor of 8.
|
||||||
|
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
|
||||||
|
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
|
||||||
|
#define FATTN_KQ_MAX_OFFSET 0.6931f
|
||||||
|
|
||||||
typedef void (* fattn_kernel_t)(
|
typedef void (* fattn_kernel_t)(
|
||||||
const char * __restrict__ Q,
|
const char * __restrict__ Q,
|
||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
|
|
|
||||||
|
|
@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||||
if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
|
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||||
if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||||
// Turing + Volta:
|
// Turing + Volta:
|
||||||
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
|
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||||
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
|
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
|
||||||
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
|
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||||
|
|
||||||
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
|
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -270,7 +270,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
|
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
|
||||||
}
|
}
|
||||||
|
|
||||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
|
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
|
||||||
|
|
||||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
|
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
|
||||||
KQ_reg[j] = sum;
|
KQ_reg[j] = sum;
|
||||||
|
|
|
||||||
|
|
@ -220,7 +220,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
|
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
|
||||||
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||||
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
|
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
|
||||||
}
|
}
|
||||||
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
|
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue