From b73d1557eb7de0eb9bb108de3da00e07c2381da2 Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Fri, 13 Mar 2026 13:25:50 -0700 Subject: [PATCH 1/4] ggml-cuda: native bf16 flash attention for vec and tile kernels mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo --- ggml/src/ggml-cuda/CMakeLists.txt | 8 +- ggml/src/ggml-cuda/fattn-common.cuh | 140 +++++++++++++++++- ggml/src/ggml-cuda/fattn-tile.cuh | 126 ++++++++++------ ggml/src/ggml-cuda/fattn-vec.cuh | 22 ++- ggml/src/ggml-cuda/fattn.cu | 16 ++ .../fattn-vec-instance-bf16-bf16.cu | 7 + .../fattn-vec-instance-bf16-f16.cu | 7 + .../fattn-vec-instance-bf16-q4_0.cu | 7 + .../fattn-vec-instance-bf16-q4_1.cu | 7 + .../fattn-vec-instance-bf16-q5_0.cu | 7 + .../fattn-vec-instance-bf16-q5_1.cu | 7 + .../fattn-vec-instance-bf16-q8_0.cu | 7 + .../fattn-vec-instance-f16-bf16.cu | 7 + .../fattn-vec-instance-q4_0-bf16.cu | 7 + .../fattn-vec-instance-q4_1-bf16.cu | 7 + .../fattn-vec-instance-q5_0-bf16.cu | 7 + .../fattn-vec-instance-q5_1-bf16.cu | 7 + .../fattn-vec-instance-q8_0-bf16.cu | 7 + .../template-instances/generate_cu_files.py | 2 +- 19 files changed, 339 insertions(+), 66 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e..7189f64540 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,11 +116,13 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") + file(GLOB SRCS "template-instances/fattn-vec-instance-q4_0-q4_0.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") + file(GLOB SRCS "template-instances/fattn-vec-instance-q8_0-q8_0.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") + file(GLOB SRCS "template-instances/fattn-vec-instance-f16-f16.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/fattn-vec-instance-bf16-bf16.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) endif() diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b6a7460da8..e8722212cf 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + const float2 bf16_f2 = __bfloat1622float2(tmp[k_KQ_1]); + ggml_cuda_mad(sum, make_half2(bf16_f2.x, bf16_f2.y), ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#else + ggml_cuda_mad(sum, __bfloat1622float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -321,6 +352,32 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + if constexpr (std::is_same_v) { + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + half2 * dst_h2 = (half2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + const float2 f2 = __bfloat1622float2(tmp[l]); + dst_h2[l] = make_half2(f2.x, f2.y); + } + } else if constexpr (std::is_same_v) { + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = __bfloat1622float2(tmp[l]); + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +} + template static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +604,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +626,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; @@ -781,7 +842,8 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE, + const bool need_bf16_K = false, const bool need_bf16_V = false ) { constexpr int ncols = ncols1 * ncols2; @@ -798,6 +860,8 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); + GGML_ASSERT(!(need_f16_K && need_bf16_K)); + GGML_ASSERT(!(need_f16_V && need_bf16_V)); GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); GGML_ASSERT(K->nb[0] == ggml_element_size(K)); @@ -811,11 +875,13 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc K_f16(pool); - ggml_cuda_pool_alloc V_f16(pool); - ggml_cuda_pool_alloc KV_max(pool); - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc K_bf16(pool); + ggml_cuda_pool_alloc V_bf16(pool); + ggml_cuda_pool_alloc KV_max(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); const char * K_data = (const char *) K->data; size_t nb11 = K->nb[1]; @@ -889,6 +955,68 @@ void launch_fattn( } } + if (need_bf16_K && K->type != GGML_TYPE_BF16) { + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + K_bf16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(K->type); + to_bf16(K_data, K_bf16.ptr, ggml_nelements(K), main_stream); + + nb11 = nb11*bs*sizeof(nv_bfloat16)/ts; + nb12 = nb12*bs*sizeof(nv_bfloat16)/ts; + nb13 = nb13*bs*sizeof(nv_bfloat16)/ts; + } else { + GGML_ASSERT(K->nb[0] == ts); + to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(K->type); + const int64_t s01 = nb11 / ts; + const int64_t s02 = nb12 / ts; + const int64_t s03 = nb13 / ts; + to_bf16(K_data, K_bf16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + nb11 = K->ne[0] * sizeof(nv_bfloat16); + nb12 = K->ne[1] * nb11; + nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_bf16.ptr; + } + + if (need_bf16_V && V->type != GGML_TYPE_BF16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; + } else { + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_bf16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(V->type); + to_bf16(V_data, V_bf16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_bf16.ptr; + + nb21 = nb21*bs*sizeof(nv_bfloat16)/ts; + nb22 = nb22*bs*sizeof(nv_bfloat16)/ts; + nb23 = nb23*bs*sizeof(nv_bfloat16)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_bf16(V_data, V_bf16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(nv_bfloat16); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_bf16.ptr; + } + } + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int gqa_ratio = Q->ne[2] / K->ne[2]; const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f3fa80ab23..55396eb4ee 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -321,9 +321,9 @@ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, } // TODO: deduplicate with mma-f16 -template +template static __device__ __forceinline__ void flash_attn_tile_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + const T_KV * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -351,10 +351,24 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; - const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; - ggml_cuda_memcpy_1( - tile_KV + i*(J/2 + J_padding) + j, - !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + if constexpr (std::is_same_v) { + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } else { + const __align__(16) T_KV zero[cpy_ne] = {}; + __align__(16) T_KV tmp[cpy_ne]; + ggml_cuda_memcpy_1( + tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + __align__(16) half2 converted[cpy_ne]; +#pragma unroll + for (int l = 0; l < cpy_ne; ++l) { + const float2 f = __bfloat1622float2(tmp[l]); + converted[l] = make_half2(f.x, f.y); + } + ggml_cuda_memcpy_1(tile_KV + i*(J/2 + J_padding) + j, converted); + } } } } @@ -371,9 +385,9 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( ggml_cuda_unroll<7>{}(load); } -template +template static __device__ __forceinline__ void flash_attn_tile_load_tile( - const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + const T_KV * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -401,15 +415,19 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); - const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; - __align__(16) half2 tmp_h2[cpy_ne/2]; - ggml_cuda_memcpy_1( - tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + const T_KV zero[cpy_ne/2] = {}; + __align__(16) T_KV tmp_kv[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_kv, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); __align__(16) float2 tmp_f2[cpy_ne/2]; #pragma unroll for (int l = 0; l < cpy_ne/2; ++l) { - tmp_f2[l] = __half22float2(tmp_h2[l]); + if constexpr (std::is_same_v) { + tmp_f2[l] = __half22float2(tmp_kv[l]); + } else { + tmp_f2[l] = __bfloat1622float2(tmp_kv[l]); + } } ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); } @@ -428,10 +446,10 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( // Function that performs a single iteration in for the KQ matrix multiplication: template + bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KV> static __device__ __forceinline__ void flash_attn_tile_iter_KQ( T_vec_dot * const Q_tmp, - const half2 * const __restrict__ K_h2, + const T_KV * const __restrict__ K_kv, T_vec_dot * const KV_tmp, const int stride_K2, const int k_VKQ_0, @@ -446,7 +464,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column flash_attn_tile_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + (K_kv + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); __syncthreads(); #ifdef FAST_FP16_AVAILABLE @@ -503,11 +521,11 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( // Function that performs a single iteration of the main loop over up to nbatch_fa tokens. template + bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc, typename T_KV> static __device__ __forceinline__ void flash_attn_tile_iter( T_vec_dot * const Q_tmp, - const half2 * const __restrict__ K_h2, - const half2 * const __restrict__ V_h2, + const T_KV * const __restrict__ K_kv, + const T_KV * const __restrict__ V_kv, const half * const __restrict__ mask, const uint3 ne01, const float logit_softcap, @@ -555,12 +573,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { flash_attn_tile_iter_KQ( - Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + Q_tmp, K_kv, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); } if (nbatch_K_last > 0) { constexpr int k_KQ_0 = DKQ - nbatch_K_last; flash_attn_tile_iter_KQ( - Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + Q_tmp, K_kv, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); } // Apply logit softcap + mask, update KQ_max: @@ -666,7 +684,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { flash_attn_tile_load_tile - (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + (V_kv + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); __syncthreads(); #ifdef FAST_FP16_AVAILABLE @@ -735,7 +753,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( } } -template // D == head size +template // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( const char * __restrict__ Q, @@ -798,13 +816,13 @@ static __global__ void flash_attn_tile( const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + const T_KV * K_kv = (const T_KV *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const T_KV * V_kv = (const T_KV *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; - const int stride_K2 = nb11 / sizeof(half2); - const int stride_V2 = nb21 / sizeof(half2); + const int stride_K2 = nb11 / sizeof(T_KV); + const int stride_V2 = nb21 / sizeof(T_KV); const int stride_mask = nb31 / sizeof(half); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -900,14 +918,14 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + (Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + (Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { @@ -915,7 +933,7 @@ static __global__ void flash_attn_tile( for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + (Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1087,6 +1105,24 @@ static __global__ void flash_attn_tile( #endif // FLASH_ATTN_AVAILABLE } +template +static void launch_fattn_tile_kernel( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, + const int nwarps, const size_t nbytes_shared, const int nbatch_fa, const int warp_size) { + const ggml_tensor * K = dst->src[1]; + const bool bf16_kv = K->type == GGML_TYPE_BF16; + + if (bf16_kv) { + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, false, false, false, warp_size, true, true); + } else { + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + } +} + template static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -1103,9 +1139,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + launch_fattn_tile_kernel + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); return; } } @@ -1119,9 +1154,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 32; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + launch_fattn_tile_kernel + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); return; } } @@ -1130,9 +1164,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 16; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + launch_fattn_tile_kernel + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); return; } @@ -1141,9 +1174,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 8; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + launch_fattn_tile_kernel + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); return; } } @@ -1153,9 +1185,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 4; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + launch_fattn_tile_kernel + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); return; } } @@ -1164,9 +1195,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 2; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + launch_fattn_tile_kernel + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); return; } diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 7cbe32633e..f64a331094 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else @@ -516,10 +516,12 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nwarps = nthreads / WARP_SIZE; fattn_kernel_t fattn_kernel = flash_attn_ext_vec; - const bool need_f16_K = type_K == GGML_TYPE_F16; - const bool need_f16_V = type_V == GGML_TYPE_F16; + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; + const bool need_bf16_K = type_K == GGML_TYPE_BF16; + const bool need_bf16_V = type_V == GGML_TYPE_BF16; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false, WARP_SIZE, need_bf16_K, need_bf16_V); } template @@ -556,13 +558,14 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten template void ggml_cuda_flash_attn_ext_vec_case \ (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ -#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +573,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +589,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496..a25a890db6 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 0000000000..3a2fa99b05 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 0000000000..60f0f6f795 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 0000000000..489e05f08c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 0000000000..6fa3c26d30 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 0000000000..421027fb29 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 0000000000..abbc943480 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 0000000000..d641f859d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 0000000000..d1071dc243 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 0000000000..8afda31423 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 0000000000..506864ac18 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 0000000000..0bbda8371e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 0000000000..79be24daf9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 0000000000..45636e5e70 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae2..3b5ab12fc4 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -5,7 +5,7 @@ import os HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. From adc7d74a073dd712abbe4c5d1dca12088b2b57d1 Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Sat, 14 Mar 2026 10:10:09 -0700 Subject: [PATCH 2/4] ggml-cuda: address code owner review feedback reverted tile kernel changes to avoid larger refactor --- ggml/src/ggml-cuda/CMakeLists.txt | 13 ++- ggml/src/ggml-cuda/convert.cuh | 2 + ggml/src/ggml-cuda/fattn-common.cuh | 113 ++++--------------------- ggml/src/ggml-cuda/fattn-tile.cuh | 126 +++++++++++----------------- ggml/src/ggml-cuda/fattn-vec.cuh | 10 +-- 5 files changed, 73 insertions(+), 191 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 7189f64540..419862101d 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,14 +116,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec-instance-q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec-instance-q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec-instance-f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec-instance-bf16-bf16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f90..3d0f313d33 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,8 @@ template return __bfloat162float(x); } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __bfloat1622float2(x); } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e8722212cf..e5856ab6af 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -93,12 +93,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { -#ifdef V_DOT2_F32_F16_AVAILABLE - const float2 bf16_f2 = __bfloat1622float2(tmp[k_KQ_1]); - ggml_cuda_mad(sum, make_half2(bf16_f2.x, bf16_f2.y), ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); -#else - ggml_cuda_mad(sum, __bfloat1622float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); -#endif // V_DOT2_F32_F16_AVAILABLE + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); } } @@ -354,27 +349,14 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ template static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { - if constexpr (std::is_same_v) { - static_assert(ne % 2 == 0, "bad ne"); - __align__(16) nv_bfloat162 tmp[ne/2]; - ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); - half2 * dst_h2 = (half2 *) dst; + static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; #pragma unroll - for (int l = 0; l < ne/2; ++l) { - const float2 f2 = __bfloat1622float2(tmp[l]); - dst_h2[l] = make_half2(f2.x, f2.y); - } - } else if constexpr (std::is_same_v) { - static_assert(ne % 2 == 0, "bad ne"); - __align__(16) nv_bfloat162 tmp[ne/2]; - ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); - float2 * dst_f2 = (float2 *) dst; -#pragma unroll - for (int l = 0; l < ne/2; ++l) { - dst_f2[l] = __bfloat1622float2(tmp[l]); - } - } else { - static_assert(std::is_same_v, "unsupported type"); + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast(tmp[l]); } } @@ -842,8 +824,7 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE, - const bool need_bf16_K = false, const bool need_bf16_V = false + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -860,8 +841,6 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); - GGML_ASSERT(!(need_f16_K && need_bf16_K)); - GGML_ASSERT(!(need_f16_V && need_bf16_V)); GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); GGML_ASSERT(K->nb[0] == ggml_element_size(K)); @@ -875,13 +854,11 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc K_f16(pool); - ggml_cuda_pool_alloc V_f16(pool); - ggml_cuda_pool_alloc K_bf16(pool); - ggml_cuda_pool_alloc V_bf16(pool); - ggml_cuda_pool_alloc KV_max(pool); - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc KV_max(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); const char * K_data = (const char *) K->data; size_t nb11 = K->nb[1]; @@ -955,68 +932,6 @@ void launch_fattn( } } - if (need_bf16_K && K->type != GGML_TYPE_BF16) { - const size_t bs = ggml_blck_size(K->type); - const size_t ts = ggml_type_size(K->type); - - K_bf16.alloc(ggml_nelements(K)); - if (ggml_is_contiguously_allocated(K)) { - to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(K->type); - to_bf16(K_data, K_bf16.ptr, ggml_nelements(K), main_stream); - - nb11 = nb11*bs*sizeof(nv_bfloat16)/ts; - nb12 = nb12*bs*sizeof(nv_bfloat16)/ts; - nb13 = nb13*bs*sizeof(nv_bfloat16)/ts; - } else { - GGML_ASSERT(K->nb[0] == ts); - to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(K->type); - const int64_t s01 = nb11 / ts; - const int64_t s02 = nb12 / ts; - const int64_t s03 = nb13 / ts; - to_bf16(K_data, K_bf16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); - - nb11 = K->ne[0] * sizeof(nv_bfloat16); - nb12 = K->ne[1] * nb11; - nb13 = K->ne[2] * nb12; - } - K_data = (char *) K_bf16.ptr; - } - - if (need_bf16_V && V->type != GGML_TYPE_BF16) { - if (V_is_K_view) { - V_data = K_data; - nb21 = nb11; - nb22 = nb12; - nb23 = nb13; - } else { - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - V_bf16.alloc(ggml_nelements(V)); - if (ggml_is_contiguously_allocated(V)) { - to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(V->type); - to_bf16(V_data, V_bf16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_bf16.ptr; - - nb21 = nb21*bs*sizeof(nv_bfloat16)/ts; - nb22 = nb22*bs*sizeof(nv_bfloat16)/ts; - nb23 = nb23*bs*sizeof(nv_bfloat16)/ts; - } else { - GGML_ASSERT(V->nb[0] == ts); - to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(V->type); - const int64_t s01 = nb21 / ts; - const int64_t s02 = nb22 / ts; - const int64_t s03 = nb23 / ts; - to_bf16(V_data, V_bf16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); - - nb21 = V->ne[0] * sizeof(nv_bfloat16); - nb22 = V->ne[1] * nb21; - nb23 = V->ne[2] * nb22; - } - V_data = (char *) V_bf16.ptr; - } - } - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int gqa_ratio = Q->ne[2] / K->ne[2]; const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 55396eb4ee..f3fa80ab23 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -321,9 +321,9 @@ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, } // TODO: deduplicate with mma-f16 -template +template static __device__ __forceinline__ void flash_attn_tile_load_tile( - const T_KV * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -351,24 +351,10 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; - if constexpr (std::is_same_v) { - const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; - ggml_cuda_memcpy_1( - tile_KV + i*(J/2 + J_padding) + j, - !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); - } else { - const __align__(16) T_KV zero[cpy_ne] = {}; - __align__(16) T_KV tmp[cpy_ne]; - ggml_cuda_memcpy_1( - tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); - __align__(16) half2 converted[cpy_ne]; -#pragma unroll - for (int l = 0; l < cpy_ne; ++l) { - const float2 f = __bfloat1622float2(tmp[l]); - converted[l] = make_half2(f.x, f.y); - } - ggml_cuda_memcpy_1(tile_KV + i*(J/2 + J_padding) + j, converted); - } + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); } } } @@ -385,9 +371,9 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( ggml_cuda_unroll<7>{}(load); } -template +template static __device__ __forceinline__ void flash_attn_tile_load_tile( - const T_KV * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -415,19 +401,15 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); - const T_KV zero[cpy_ne/2] = {}; - __align__(16) T_KV tmp_kv[cpy_ne/2]; - ggml_cuda_memcpy_1( - tmp_kv, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + __align__(16) half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); __align__(16) float2 tmp_f2[cpy_ne/2]; #pragma unroll for (int l = 0; l < cpy_ne/2; ++l) { - if constexpr (std::is_same_v) { - tmp_f2[l] = __half22float2(tmp_kv[l]); - } else { - tmp_f2[l] = __bfloat1622float2(tmp_kv[l]); - } + tmp_f2[l] = __half22float2(tmp_h2[l]); } ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); } @@ -446,10 +428,10 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( // Function that performs a single iteration in for the KQ matrix multiplication: template + bool use_logit_softcap, bool oob_check, typename T_vec_dot> static __device__ __forceinline__ void flash_attn_tile_iter_KQ( T_vec_dot * const Q_tmp, - const T_KV * const __restrict__ K_kv, + const half2 * const __restrict__ K_h2, T_vec_dot * const KV_tmp, const int stride_K2, const int k_VKQ_0, @@ -464,7 +446,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column flash_attn_tile_load_tile - (K_kv + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); __syncthreads(); #ifdef FAST_FP16_AVAILABLE @@ -521,11 +503,11 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( // Function that performs a single iteration of the main loop over up to nbatch_fa tokens. template + bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc> static __device__ __forceinline__ void flash_attn_tile_iter( T_vec_dot * const Q_tmp, - const T_KV * const __restrict__ K_kv, - const T_KV * const __restrict__ V_kv, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, const half * const __restrict__ mask, const uint3 ne01, const float logit_softcap, @@ -573,12 +555,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { flash_attn_tile_iter_KQ( - Q_tmp, K_kv, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); } if (nbatch_K_last > 0) { constexpr int k_KQ_0 = DKQ - nbatch_K_last; flash_attn_tile_iter_KQ( - Q_tmp, K_kv, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); } // Apply logit softcap + mask, update KQ_max: @@ -684,7 +666,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { flash_attn_tile_load_tile - (V_kv + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); __syncthreads(); #ifdef FAST_FP16_AVAILABLE @@ -753,7 +735,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( } } -template // D == head size +template // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( const char * __restrict__ Q, @@ -816,13 +798,13 @@ static __global__ void flash_attn_tile( const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); - const T_KV * K_kv = (const T_KV *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const T_KV * V_kv = (const T_KV *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; - const int stride_K2 = nb11 / sizeof(T_KV); - const int stride_V2 = nb21 / sizeof(T_KV); + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -918,14 +900,14 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { @@ -933,7 +915,7 @@ static __global__ void flash_attn_tile( for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1105,24 +1087,6 @@ static __global__ void flash_attn_tile( #endif // FLASH_ATTN_AVAILABLE } -template -static void launch_fattn_tile_kernel( - ggml_backend_cuda_context & ctx, ggml_tensor * dst, - const int nwarps, const size_t nbytes_shared, const int nbatch_fa, const int warp_size) { - const ggml_tensor * K = dst->src[1]; - const bool bf16_kv = K->type == GGML_TYPE_BF16; - - if (bf16_kv) { - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, false, false, false, warp_size, true, true); - } else { - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - } -} - template static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -1139,8 +1103,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn_tile_kernel - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); return; } } @@ -1154,8 +1119,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 32; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn_tile_kernel - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); return; } } @@ -1164,8 +1130,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 16; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn_tile_kernel - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); return; } @@ -1174,8 +1141,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 8; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn_tile_kernel - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); return; } } @@ -1185,8 +1153,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 4; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn_tile_kernel - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); return; } } @@ -1195,8 +1164,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 2; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn_tile_kernel - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); return; } diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f64a331094..0b40c903a0 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -516,12 +516,10 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nwarps = nthreads / WARP_SIZE; fattn_kernel_t fattn_kernel = flash_attn_ext_vec; - const bool need_f16_K = type_K == GGML_TYPE_F16; - const bool need_f16_V = type_V == GGML_TYPE_F16; - const bool need_bf16_K = type_K == GGML_TYPE_BF16; - const bool need_bf16_V = type_V == GGML_TYPE_BF16; + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false, WARP_SIZE, need_bf16_K, need_bf16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template @@ -558,7 +556,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten template void ggml_cuda_flash_attn_ext_vec_case \ (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ -#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ From d6b945ef5d46901181120eea38db6ff04aeb97d9 Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Sun, 15 Mar 2026 04:15:58 -0700 Subject: [PATCH 3/4] fix ci failures on turing and hip --- ggml/src/ggml-cuda/convert.cuh | 4 ++++ ggml/src/ggml-hip/CMakeLists.txt | 11 +++++------ ggml/src/ggml-musa/CMakeLists.txt | 11 +++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 3d0f313d33..4c6667db33 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -42,7 +42,11 @@ template } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); } else if constexpr(std::is_same_v && std::is_same_v) { +#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#endif } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index b44ed0f721..8e3fefb01e 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -74,12 +74,11 @@ if (GGML_CUDA_FA_ALL_QUANTS) list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977..cc53c812ce 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) From 0ab7687386d3b8b9b8e2654cd82b3700c6667c6a Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Sun, 15 Mar 2026 12:49:44 -0700 Subject: [PATCH 4/4] fix bf16 vec kernel compile on hip v_dot2 platforms --- ggml/src/ggml-cuda/fattn-common.cuh | 6 +++++- ggml/src/ggml-cuda/fattn-vec.cuh | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e5856ab6af..a48d349268 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -93,7 +93,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -609,7 +613,7 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; } else if constexpr (type_V == GGML_TYPE_BF16) { - return dequantize_V_bf16; + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 0b40c903a0..f0bd42a576 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll