sycl : enhance fattn perf (#21185)
This commit is contained in:
parent
90aa83c6bd
commit
62278cedde
|
|
@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
|
|||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
|
|||
sycl::half2 * const __restrict__ tile_KV,
|
||||
const int stride_KV,
|
||||
const int i_sup) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
auto load = [&] (const int n) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int stride_j = warp_size >> n;
|
||||
|
||||
if (stride_j == 0) {
|
||||
|
|
@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
|
|||
|
||||
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
|
||||
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#ifdef SYCL_FAST_FP16
|
||||
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
|
||||
|
|
@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
|
|||
}
|
||||
|
||||
if (k_KQ_0 + nbatch_K < DKQ) {
|
||||
item_ct1.barrier(); // Sync not needed on last iteration.
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
|||
const int k_VKQ_max,
|
||||
const int col_Q_0,
|
||||
float * KQ_max_new_shared) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
|
|
@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
|||
}
|
||||
|
||||
if constexpr (np == 1) {
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
} else {
|
||||
static_assert(cpw == 1, "bad cpw");
|
||||
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
|
||||
}
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
|
||||
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
|
||||
}
|
||||
|
|
@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
|||
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
|
||||
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
|
||||
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#ifdef SYCL_FAST_FP16
|
||||
#pragma unroll
|
||||
|
|
@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
|||
}
|
||||
}
|
||||
#endif // SYCL_FAST_FP16
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q,
|
|||
}
|
||||
}
|
||||
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
|
||||
|
|
@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q,
|
|||
return;
|
||||
}
|
||||
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#pragma unroll
|
||||
for (int ip = 1; ip < np; ++ip) {
|
||||
|
|
@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
|
|||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
if (Q->ne[1] > 16/ncols2) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
if (DV < 512 && Q->ne[1] < 32) {
|
||||
if constexpr (ncols2 <= 32) {
|
||||
if (Q->ne[1] > 16/ncols2) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] > 4/ncols2) {
|
||||
constexpr int cols_per_block = 8;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] > 4/ncols2) {
|
||||
constexpr int cols_per_block = 8;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue