diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 262f88204e..419862101d 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB   SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB   SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB   SRCS "template-instances/fattn-vec*f16-f16.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        list(APPEND GGML_SOURCES_CUDA
+            template-instances/fattn-vec-instance-f16-f16.cu
+            template-instances/fattn-vec-instance-q4_0-q4_0.cu
+            template-instances/fattn-vec-instance-q8_0-q8_0.cu
+            template-instances/fattn-vec-instance-bf16-bf16.cu)
     endif()

 ggml_add_backend_library(ggml-cuda
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index 09f9a33f90..4c6667db33 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -41,6 +41,12 @@ template<typename dst_t, typename src_t>
         return __bfloat162float(x);
     } else if constexpr(std::is_same_v<dst_t, half2> && std::is_same_v<src_t, float2>) {
         return __float22half2_rn(x);
+    } else if constexpr(std::is_same_v<dst_t, float2> && std::is_same_v<src_t, nv_bfloat162>) {
+#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+        return __bfloat1622float2(x);
+#else
+        return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
+#endif
     } else if constexpr(std::is_same_v<dst_t, float2> && std::is_same_v<src_t, half2>) {
         // bypass compile error on cuda 12.0.1
 #ifdef GGML_USE_HIP
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index e9abdf288c..24e0d7eaee 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -74,6 +74,36 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
     return sum;
 }

+template <int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
+        const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+    const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
+        __align__(16) nv_bfloat162 tmp[cpy_ne];
+        ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
+#pragma unroll
+        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
+            ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
+#else
+            ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // V_DOT2_F32_F16_AVAILABLE
+        }
+    }
+
+    return sum;
+}
+
 template <int D, int nthreads>
 static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -321,6 +351,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
     }
 }

+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
+    static_assert(ne % 2 == 0, "bad ne");
+    __align__(16) nv_bfloat162 tmp[ne/2];
+    ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, (const nv_bfloat16 *) vx + i0);
+    float2 * dst_f2 = (float2 *) dst;
+#pragma unroll
+    for (int l = 0; l < ne/2; ++l) {
+        dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
+    }
+}
+
 template <typename T, int ne>
 static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -547,6 +590,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
         return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
     } else if constexpr (type_K == GGML_TYPE_Q8_0) {
         return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_BF16) {
+        return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
     } else {
         static_assert(type_K == -1, "bad type");
         return nullptr;
@@ -567,6 +612,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
         return dequantize_V_q5_1<T, ne>;
     } else if constexpr (type_V == GGML_TYPE_Q8_0) {
         return dequantize_V_q8_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_BF16) {
+        return dequantize_V_bf16<float, ne>;
     } else {
         static_assert(type_V == -1, "bad type");
         return nullptr;
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 7cbe32633e..f0bd42a576 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
 #endif // GGML_USE_HIP
     constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
-    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
-    constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
+    constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
+    constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;

     static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
     static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");

-    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
+    constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
     constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;

     constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
-    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;

 #ifdef V_DOT2_F32_F16_AVAILABLE
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
 #else
@@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
 #pragma unroll
         for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
             half2 tmp[V_rows_per_thread/2];
-            dequantize_V(V + k*nb21, tmp,
-                2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+            if constexpr (type_V == GGML_TYPE_BF16) {
+                float2 tmp_f[V_rows_per_thread/2];
+                dequantize_V(V + k*nb21, tmp_f,
+                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+                    tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
+                }
+            } else {
+                dequantize_V(V + k*nb21, tmp,
+                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+            }
 #pragma unroll
             for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
 #pragma unroll
@@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \

 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
@@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)

 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
@@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)

 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
@@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 85c177f496..a25a890db6 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)

     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)

     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
@@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)

     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
@@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)

     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
@@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)

     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
@@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
 #else
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
 #endif // GGML_CUDA_FA_ALL_QUANTS

     GGML_ABORT("fatal error");
@@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_BF16:
             break;
         default:
             return BEST_FATTN_KERNEL_NONE;
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu
new file mode 100644
index 0000000000..3a2fa99b05
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu
new file mode 100644
index 0000000000..60f0f6f795
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu
new file mode 100644
index 0000000000..489e05f08c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu
new file mode 100644
index 0000000000..6fa3c26d30
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu
new file mode 100644
index 0000000000..421027fb29
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu
new file mode 100644
index 0000000000..abbc943480
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu
new file mode 100644
index 0000000000..d641f859d8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu
new file mode 100644
index 0000000000..d1071dc243
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu
new file mode 100644
index 0000000000..8afda31423
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu
new file mode 100644
index 0000000000..506864ac18
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu
new file mode 100644
index 0000000000..0bbda8371e
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu
new file mode 100644
index 0000000000..79be24daf9
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu
new file mode 100644
index 0000000000..45636e5e70
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index e382df1ae2..3b5ab12fc4 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -5,7 +5,7 @@ import os

 HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]

-TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
+TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]

 SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index b44ed0f721..8e3fefb01e 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -74,12 +74,11 @@ if (GGML_CUDA_FA_ALL_QUANTS)
         list(APPEND GGML_SOURCES_ROCM ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_ROCM ${SRCS})
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_ROCM ${SRCS})
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
-        list(APPEND GGML_SOURCES_ROCM ${SRCS})
+        list(APPEND GGML_SOURCES_ROCM
+            ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
     endif()

 ggml_add_backend_library(ggml-hip
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index d76cb51977..cc53c812ce 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND)
         list(APPEND GGML_SOURCES_MUSA ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_MUSA ${SRCS})
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_MUSA ${SRCS})
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
-        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        list(APPEND GGML_SOURCES_MUSA
+            ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
     endif()

 set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)