CUDA: no FP16 arithmetic for vector FA kernel (#17558)
This commit is contained in:
parent
35cf8887e1
commit
73955f7d2a
|
|
@ -558,8 +558,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
|
||||||
acc += v.y*u.y;
|
acc += v.y*u.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
|
||||||
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
||||||
|
#define V_DOT2_F32_F16_AVAILABLE
|
||||||
|
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
||||||
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
|
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
|
||||||
#else
|
#else
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef FAST_FP16_AVAILABLE
|
||||||
|
|
@ -571,7 +575,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
|
||||||
acc += tmpv.x * tmpu.x;
|
acc += tmpv.x * tmpu.x;
|
||||||
acc += tmpv.y * tmpu.y;
|
acc += tmpv.y * tmpu.y;
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // FAST_FP16_AVAILABLE
|
||||||
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
|
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
|
||||||
|
|
|
||||||
|
|
@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
||||||
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||||
#else
|
#else
|
||||||
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||||
#endif // FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
|
||||||
|
|
||||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||||
#else
|
#else
|
||||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
|
|
||||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||||
|
|
||||||
|
|
@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
|
||||||
|
|
||||||
constexpr int ne_KQ = ncols*D;
|
constexpr int ne_KQ = ncols*D;
|
||||||
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||||
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||||
#else
|
#else
|
||||||
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||||
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
|
|
||||||
float KQ_max[ncols];
|
float KQ_max[ncols];
|
||||||
float KQ_sum[ncols];
|
float KQ_sum[ncols];
|
||||||
|
|
@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
|
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
|
||||||
#else
|
#else
|
||||||
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||||
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||||
if constexpr (Q_q8_1) {
|
if constexpr (Q_q8_1) {
|
||||||
|
|
@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
} else {
|
} else {
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
const half2 scale_h2 = make_half2(scale, scale);
|
const half2 scale_h2 = make_half2(scale, scale);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
|
@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
Q_reg[j][k].y *= scale;
|
Q_reg[j][k].y *= scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||||
|
|
@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
||||||
KQ[j*nthreads + tid] = KQ_reg[j];
|
KQ[j*nthreads + tid] = KQ_reg[j];
|
||||||
|
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||||
|
|
@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||||
}
|
}
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef GGML_USE_HIP
|
#ifndef GGML_USE_HIP
|
||||||
|
|
@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
||||||
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
||||||
|
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
half2 KQ_k[ncols];
|
half2 KQ_k[ncols];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
|
@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
|
|
||||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
|
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
|
||||||
|
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||||
|
|
@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||||
}
|
}
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
||||||
KQ_max[j_VKQ] = kqmax_new;
|
KQ_max[j_VKQ] = kqmax_new;
|
||||||
|
|
||||||
#ifdef FAST_FP16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
||||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
||||||
|
|
||||||
|
|
@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
|
||||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
||||||
}
|
}
|
||||||
#endif // FAST_FP16_AVAILABLE
|
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||||
|
|
||||||
KQ_sum[j_VKQ] *= kqmax_scale;
|
KQ_sum[j_VKQ] *= kqmax_scale;
|
||||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue