ggml-cuda: native bf16 flash attention for vec and tile kernels
mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo
This commit is contained in:
parent
f17b3be63f
commit
b73d1557eb
|
|
@ -116,11 +116,13 @@ if (CUDAToolkit_FOUND)
|
|||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||
else()
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-vec-instance-q4_0-q4_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-vec-instance-q8_0-q8_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-vec-instance-f16-f16.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec-instance-bf16-bf16.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
|||
return sum;
|
||||
}
|
||||
|
||||
template <int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
||||
__align__(16) nv_bfloat162 tmp[cpy_ne];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const float2 bf16_f2 = __bfloat1622float2(tmp[k_KQ_1]);
|
||||
ggml_cuda_mad(sum, make_half2(bf16_f2.x, bf16_f2.y), ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_mad(sum, __bfloat1622float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
|
@ -321,6 +352,32 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
static_assert(ne % 2 == 0, "bad ne");
|
||||
__align__(16) nv_bfloat162 tmp[ne/2];
|
||||
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
|
||||
half2 * dst_h2 = (half2 *) dst;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne/2; ++l) {
|
||||
const float2 f2 = __bfloat1622float2(tmp[l]);
|
||||
dst_h2[l] = make_half2(f2.x, f2.y);
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
static_assert(ne % 2 == 0, "bad ne");
|
||||
__align__(16) nv_bfloat162 tmp[ne/2];
|
||||
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
|
||||
float2 * dst_f2 = (float2 *) dst;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne/2; ++l) {
|
||||
dst_f2[l] = __bfloat1622float2(tmp[l]);
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
|
@ -547,6 +604,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
|||
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
|
||||
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_BF16) {
|
||||
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
|
||||
} else {
|
||||
static_assert(type_K == -1, "bad type");
|
||||
return nullptr;
|
||||
|
|
@ -567,6 +626,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
|
|||
return dequantize_V_q5_1<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
|
||||
return dequantize_V_q8_0<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_BF16) {
|
||||
return dequantize_V_bf16<T, ne>;
|
||||
} else {
|
||||
static_assert(type_V == -1, "bad type");
|
||||
return nullptr;
|
||||
|
|
@ -781,7 +842,8 @@ static __global__ void flash_attn_combine_results(
|
|||
template <int DV, int ncols1, int ncols2>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE,
|
||||
const bool need_bf16_K = false, const bool need_bf16_V = false
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
|
|
@ -798,6 +860,8 @@ void launch_fattn(
|
|||
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!(need_f16_K && need_bf16_K));
|
||||
GGML_ASSERT(!(need_f16_V && need_bf16_V));
|
||||
|
||||
GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT(K->nb[0] == ggml_element_size(K));
|
||||
|
|
@ -811,11 +875,13 @@ void launch_fattn(
|
|||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<int> KV_max(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> K_bf16(pool);
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> V_bf16(pool);
|
||||
ggml_cuda_pool_alloc<int> KV_max(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
const char * K_data = (const char *) K->data;
|
||||
size_t nb11 = K->nb[1];
|
||||
|
|
@ -889,6 +955,68 @@ void launch_fattn(
|
|||
}
|
||||
}
|
||||
|
||||
if (need_bf16_K && K->type != GGML_TYPE_BF16) {
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
const size_t ts = ggml_type_size(K->type);
|
||||
|
||||
K_bf16.alloc(ggml_nelements(K));
|
||||
if (ggml_is_contiguously_allocated(K)) {
|
||||
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(K->type);
|
||||
to_bf16(K_data, K_bf16.ptr, ggml_nelements(K), main_stream);
|
||||
|
||||
nb11 = nb11*bs*sizeof(nv_bfloat16)/ts;
|
||||
nb12 = nb12*bs*sizeof(nv_bfloat16)/ts;
|
||||
nb13 = nb13*bs*sizeof(nv_bfloat16)/ts;
|
||||
} else {
|
||||
GGML_ASSERT(K->nb[0] == ts);
|
||||
to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(K->type);
|
||||
const int64_t s01 = nb11 / ts;
|
||||
const int64_t s02 = nb12 / ts;
|
||||
const int64_t s03 = nb13 / ts;
|
||||
to_bf16(K_data, K_bf16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb11 = K->ne[0] * sizeof(nv_bfloat16);
|
||||
nb12 = K->ne[1] * nb11;
|
||||
nb13 = K->ne[2] * nb12;
|
||||
}
|
||||
K_data = (char *) K_bf16.ptr;
|
||||
}
|
||||
|
||||
if (need_bf16_V && V->type != GGML_TYPE_BF16) {
|
||||
if (V_is_K_view) {
|
||||
V_data = K_data;
|
||||
nb21 = nb11;
|
||||
nb22 = nb12;
|
||||
nb23 = nb13;
|
||||
} else {
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
V_bf16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(V->type);
|
||||
to_bf16(V_data, V_bf16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_bf16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(nv_bfloat16)/ts;
|
||||
nb22 = nb22*bs*sizeof(nv_bfloat16)/ts;
|
||||
nb23 = nb23*bs*sizeof(nv_bfloat16)/ts;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_bf16(V_data, V_bf16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(nv_bfloat16);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
}
|
||||
V_data = (char *) V_bf16.ptr;
|
||||
}
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
|
||||
|
|
|
|||
|
|
@ -321,9 +321,9 @@ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ,
|
|||
}
|
||||
|
||||
// TODO: deduplicate with mma-f16
|
||||
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
|
||||
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check, typename T_KV>
|
||||
static __device__ __forceinline__ void flash_attn_tile_load_tile(
|
||||
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
|
||||
const T_KV * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
|
|
@ -351,10 +351,24 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
|
|||
for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
|
||||
const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
|
||||
|
||||
const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
|
||||
ggml_cuda_memcpy_1<cpy_nb>(
|
||||
tile_KV + i*(J/2 + J_padding) + j,
|
||||
!oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
|
||||
if constexpr (std::is_same_v<T_KV, half2>) {
|
||||
const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
|
||||
ggml_cuda_memcpy_1<cpy_nb>(
|
||||
tile_KV + i*(J/2 + J_padding) + j,
|
||||
!oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
|
||||
} else {
|
||||
const __align__(16) T_KV zero[cpy_ne] = {};
|
||||
__align__(16) T_KV tmp[cpy_ne];
|
||||
ggml_cuda_memcpy_1<cpy_nb>(
|
||||
tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
|
||||
__align__(16) half2 converted[cpy_ne];
|
||||
#pragma unroll
|
||||
for (int l = 0; l < cpy_ne; ++l) {
|
||||
const float2 f = __bfloat1622float2(tmp[l]);
|
||||
converted[l] = make_half2(f.x, f.y);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(converted)>(tile_KV + i*(J/2 + J_padding) + j, converted);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -371,9 +385,9 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
|
|||
ggml_cuda_unroll<7>{}(load);
|
||||
}
|
||||
|
||||
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
|
||||
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check, typename T_KV>
|
||||
static __device__ __forceinline__ void flash_attn_tile_load_tile(
|
||||
const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
|
||||
const T_KV * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
|
|
@ -401,15 +415,19 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
|
|||
for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
|
||||
const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
|
||||
|
||||
const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
|
||||
__align__(16) half2 tmp_h2[cpy_ne/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
|
||||
const T_KV zero[cpy_ne/2] = {};
|
||||
__align__(16) T_KV tmp_kv[cpy_ne/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_kv)>(
|
||||
tmp_kv, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
|
||||
|
||||
__align__(16) float2 tmp_f2[cpy_ne/2];
|
||||
#pragma unroll
|
||||
for (int l = 0; l < cpy_ne/2; ++l) {
|
||||
tmp_f2[l] = __half22float2(tmp_h2[l]);
|
||||
if constexpr (std::is_same_v<T_KV, half2>) {
|
||||
tmp_f2[l] = __half22float2(tmp_kv[l]);
|
||||
} else {
|
||||
tmp_f2[l] = __bfloat1622float2(tmp_kv[l]);
|
||||
}
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
|
||||
}
|
||||
|
|
@ -428,10 +446,10 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
|
|||
|
||||
// Function that performs a single iteration in for the KQ matrix multiplication:
|
||||
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
|
||||
bool use_logit_softcap, bool oob_check, typename T_vec_dot>
|
||||
bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KV>
|
||||
static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
|
||||
T_vec_dot * const Q_tmp,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const T_KV * const __restrict__ K_kv,
|
||||
T_vec_dot * const KV_tmp,
|
||||
const int stride_K2,
|
||||
const int k_VKQ_0,
|
||||
|
|
@ -446,7 +464,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
|
|||
constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
|
||||
|
||||
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
|
||||
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
|
||||
(K_kv + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
|
|
@ -503,11 +521,11 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
|
|||
|
||||
// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
|
||||
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
|
||||
bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
|
||||
bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc, typename T_KV>
|
||||
static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
T_vec_dot * const Q_tmp,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const T_KV * const __restrict__ K_kv,
|
||||
const T_KV * const __restrict__ V_kv,
|
||||
const half * const __restrict__ mask,
|
||||
const uint3 ne01,
|
||||
const float logit_softcap,
|
||||
|
|
@ -555,12 +573,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
|||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
|
||||
flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
|
||||
Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
|
||||
Q_tmp, K_kv, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
|
||||
}
|
||||
if (nbatch_K_last > 0) {
|
||||
constexpr int k_KQ_0 = DKQ - nbatch_K_last;
|
||||
flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
|
||||
Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
|
||||
Q_tmp, K_kv, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
|
||||
}
|
||||
|
||||
// Apply logit softcap + mask, update KQ_max:
|
||||
|
|
@ -666,7 +684,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
|||
#pragma unroll
|
||||
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
|
||||
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
|
||||
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
|
||||
(V_kv + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
|
|
@ -735,7 +753,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
|||
}
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV = half2> // D == head size
|
||||
__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
|
|
@ -798,13 +816,13 @@ static __global__ void flash_attn_tile(
|
|||
const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
|
||||
const T_KV * K_kv = (const T_KV *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
||||
const T_KV * V_kv = (const T_KV *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
|
||||
|
||||
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
|
||||
|
||||
const int stride_K2 = nb11 / sizeof(half2);
|
||||
const int stride_V2 = nb21 / sizeof(half2);
|
||||
const int stride_K2 = nb11 / sizeof(T_KV);
|
||||
const int stride_V2 = nb21 / sizeof(T_KV);
|
||||
const int stride_mask = nb31 / sizeof(half);
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
|
|
@ -900,14 +918,14 @@ static __global__ void flash_attn_tile(
|
|||
while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
|
||||
constexpr bool oob_check = false;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
(Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
k_VKQ_0 += gridDim.y*nbatch_fa;
|
||||
}
|
||||
if (k_VKQ_0 < k_VKQ_max) {
|
||||
constexpr bool oob_check = true;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
(Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -915,7 +933,7 @@ static __global__ void flash_attn_tile(
|
|||
for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
|
||||
constexpr bool oob_check = false;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
(Q_tmp, K_kv, V_kv, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
}
|
||||
}
|
||||
|
|
@ -1087,6 +1105,24 @@ static __global__ void flash_attn_tile(
|
|||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_kernel(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst,
|
||||
const int nwarps, const size_t nbytes_shared, const int nbatch_fa, const int warp_size) {
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const bool bf16_kv = K->type == GGML_TYPE_BF16;
|
||||
|
||||
if (bf16_kv) {
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, ncols1, ncols2, use_logit_softcap, nv_bfloat162>;
|
||||
launch_fattn<DV, ncols1, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, false, false, false, warp_size, true, true);
|
||||
} else {
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, ncols1, ncols2, use_logit_softcap, half2>;
|
||||
launch_fattn<DV, ncols1, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
|
@ -1103,9 +1139,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
|||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
launch_fattn_tile_kernel<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -1119,9 +1154,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
|||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
launch_fattn_tile_kernel<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -1130,9 +1164,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
|||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
launch_fattn_tile_kernel<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -1141,9 +1174,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
|||
constexpr int cols_per_block = 8;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
launch_fattn_tile_kernel<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -1153,9 +1185,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
|||
constexpr int cols_per_block = 4;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
launch_fattn_tile_kernel<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -1164,9 +1195,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
|||
constexpr int cols_per_block = 2;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
launch_fattn_tile_kernel<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, warp_size);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
|
|||
#endif // GGML_USE_HIP
|
||||
|
||||
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
|
||||
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
|
||||
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
|
||||
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
|
||||
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
|
||||
|
||||
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
|
||||
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
|
||||
|
||||
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
|
||||
constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
|
||||
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
|
||||
|
||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||
#else
|
||||
|
|
@ -516,10 +516,12 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
|
|||
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
const bool need_bf16_K = type_K == GGML_TYPE_BF16;
|
||||
const bool need_bf16_V = type_V == GGML_TYPE_BF16;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false, WARP_SIZE, need_bf16_K, need_bf16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
|
@ -556,13 +558,14 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
|
|||
template void ggml_cuda_flash_attn_ext_vec_case \
|
||||
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
|
||||
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
|
||||
|
|
@ -570,6 +573,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
|
|||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
|
||||
|
|
@ -577,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
|
|||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
|
||||
|
|
@ -584,3 +589,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
|
|||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)
|
||||
|
|
|
|||
|
|
@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
|
|
@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
|
||||
|
|
@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
|
||||
|
|
@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
|
||||
|
|
@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
|
||||
|
|
@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
|
||||
#else
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_BF16:
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
|
||||
|
|
@ -5,7 +5,7 @@ import os
|
|||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
|
||||
|
||||
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue