diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 02443b8c63..2750117aa9 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -10,6 +10,12 @@
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+// log(2) = 0.6931; by adding this to the KQ maximum used for the softmax, the numerical range representable
+// by the VKQ accumulators is effectively shifted up by a factor of 2.
+// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
+// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication, this should be negligible.
+#define FATTN_KQ_MAX_OFFSET 0.6931f
+
 typedef void (* fattn_kernel_t)(
         const char * __restrict__ Q,
         const char * __restrict__ K,
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index b6250cf794..ade0773dad 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
+                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
                     // Turing + Volta:
-                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
+                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 63b235674e..8afc1daaeb 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
                     slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
 
-                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
             }
         }
 
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 0bae9849a9..4d167b95a0 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -270,7 +270,7 @@ static __global__ void flash_attn_ext_vec(
                 sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
             }
 
-            KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
+            KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
 
             if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
                 KQ_reg[j] = sum;
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 0d81f0aae0..8694fd06c7 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -220,7 +220,7 @@ static __global__ void flash_attn_ext_f16(
                 KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
                     __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
 
-                KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
+                KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
             }
 
             KQ_max_new = warp_reduce_max(KQ_max_new);
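
Not part of the patch, but possibly useful context: a small host-side C++ sketch of the headroom argument behind FATTN_KQ_MAX_OFFSET. Raising the running KQ maximum by log(2) scales every exp(KQ - KQ_max) term of the softmax numerator by exp(-log(2)) = 0.5, so an FP16 VKQ accumulator (largest finite value 65504) can take roughly twice the accumulated magnitude before overflowing. Only FATTN_KQ_MAX_OFFSET is taken from the patch; the helper names and toy numbers below are made up for illustration, and the real kernels accumulate online and in half precision rather than in two passes over a float vector.

```cpp
// Standalone illustration (not part of the patch) of why shifting the softmax maximum by
// FATTN_KQ_MAX_OFFSET = log(2) roughly doubles the headroom of an FP16 VKQ accumulator.
#include <math.h>
#include <stdio.h>
#include <vector>

#define FATTN_KQ_MAX_OFFSET 0.6931f   // log(2), the constant the patch adds to the running KQ maximum
static const float FP16_MAX = 65504.0f; // largest finite half-precision value

// Mimics the per-output-element sum the kernels accumulate: sum over keys of exp(KQ - kq_max) * v,
// with kq_max optionally raised by the offset.
static float accumulate_vkq(const std::vector<float> & KQ, float v, bool use_offset) {
    float kq_max = -INFINITY;
    for (float s : KQ) {
        kq_max = fmaxf(kq_max, use_offset ? s + FATTN_KQ_MAX_OFFSET : s);
    }
    float acc = 0.0f;
    for (float s : KQ) {
        acc += expf(s - kq_max) * v; // with the offset every term is scaled by exp(-log(2)) = 0.5
    }
    return acc;
}

int main() {
    // Worst case for the accumulator: many keys with identical (maximal) KQ values and large V entries.
    std::vector<float> KQ(2048, 10.0f);
    const float v = 48.0f;

    const float acc_plain  = accumulate_vkq(KQ, v, /*use_offset=*/false); // 2048 * 48 = 98304  > 65504
    const float acc_offset = accumulate_vkq(KQ, v, /*use_offset=*/true);  // ~2048 * 24 ~ 49154 < 65504

    printf("without offset: %.1f (overflows FP16: %s)\n", acc_plain,  acc_plain  > FP16_MAX ? "yes" : "no");
    printf("with    offset: %.1f (overflows FP16: %s)\n", acc_offset, acc_offset > FP16_MAX ? "yes" : "no");
    return 0;
}
```

Since the KQ row sums used for the final normalization are computed with the same shifted maximum, the factor of 0.5 cancels in the normalized output, which is presumably why the patch only has to touch the places where KQ_max_new is computed.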