diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index e5856ab6af..a48d349268 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -93,7 +93,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
         ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
+            ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
+#else
             ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -609,7 +613,7 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
     } else if constexpr (type_V == GGML_TYPE_Q8_0) {
         return dequantize_V_q8_0;
     } else if constexpr (type_V == GGML_TYPE_BF16) {
-        return dequantize_V_bf16;
+        return dequantize_V_bf16;
     } else {
         static_assert(type_V == -1, "bad type");
         return nullptr;
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 0b40c903a0..f0bd42a576 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
 #pragma unroll
         for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
             half2 tmp[V_rows_per_thread/2];
-            dequantize_V(V + k*nb21, tmp,
-                2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+            if constexpr (type_V == GGML_TYPE_BF16) {
+                float2 tmp_f[V_rows_per_thread/2];
+                dequantize_V(V + k*nb21, tmp_f,
+                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+                    tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
+                }
+            } else {
+                dequantize_V(V + k*nb21, tmp,
+                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+            }
 #pragma unroll
             for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
 #pragma unroll