sycl : enhance fattn perf (#21185)

This commit is contained in:
Neo Zhang 2026-03-31 18:31:50 +08:00 committed by GitHub
parent 90aa83c6bd
commit 62278cedde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 43 additions and 40 deletions

View File

@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
return 0;
}
@@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
sycl::half2 * const __restrict__ tile_KV,
const int stride_KV,
const int i_sup) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
auto load = [&] (const int n) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int stride_j = warp_size >> n;
if (stride_j == 0) {
@@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
@@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
}
if (k_KQ_0 + nbatch_K < DKQ) {
item_ct1.barrier(); // Sync not needed on last iteration.
item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
}
}
@@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
const int k_VKQ_max,
const int col_Q_0,
float * KQ_max_new_shared) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
}
if constexpr (np == 1) {
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
} else {
static_assert(cpw == 1, "bad cpw");
if (item_ct1.get_local_id(2) == 0) {
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
}
@@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
#pragma unroll
@@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
}
}
#endif // SYCL_FAST_FP16
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
@@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q,
}
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main loop over KV cache:
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
@@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q,
return;
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#pragma unroll
for (int ip = 1; ip < np; ++ip) {
@@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
constexpr size_t nbytes_shared = 0;
if constexpr (DV <= 256) {
if (Q->ne[1] > 16/ncols2) {
constexpr int cols_per_block = 32;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
if (DV < 512 && Q->ne[1] < 32) {
if constexpr (ncols2 <= 32) {
if (Q->ne[1] > 16/ncols2) {
constexpr int cols_per_block = 32;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
}
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
if constexpr (ncols2 <= 8) {
if (Q->ne[1] > 4/ncols2) {
constexpr int cols_per_block = 8;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
if constexpr (ncols2 <= 16) {
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
if constexpr (ncols2 <= 8) {
if (Q->ne[1] > 4/ncols2) {
constexpr int cols_per_block = 8;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
}