diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index 29fd0f8c9e..c4d24613a5 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64) return 0; } @@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const sycl::half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; auto load = [&] (const int n) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const int stride_j = warp_size >> n; if (stride_j == 0) { @@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, flash_attn_tile_load_tile (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); @@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, } if (k_KQ_0 + nbatch_K < DKQ) { - item_ct1.barrier(); // Sync not needed on last iteration. + item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration. } } @@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, const int k_VKQ_max, const int col_Q_0, float * KQ_max_new_shared) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } if constexpr (np == 1) { - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { static_assert(cpw == 1, "bad cpw"); if (item_ct1.get_local_id(2) == 0) { KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0]; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np]; KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); } @@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { flash_attn_tile_load_tile (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 #pragma unroll @@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } } #endif // SYCL_FAST_FP16 - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } } @@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q, } } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); // Main loop over KV cache: const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; @@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q, return; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll for (int ip = 1; ip < np; ++ip) { @@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm constexpr size_t nbytes_shared = 0; - if constexpr (DV <= 256) { - if (Q->ne[1] > 16/ncols2) { - constexpr int cols_per_block = 32; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if (DV < 512 && Q->ne[1] < 32) { + if constexpr (ncols2 <= 32) { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } - } - - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; - } - - if constexpr (ncols2 <= 8) { - if (Q->ne[1] > 4/ncols2) { - constexpr int cols_per_block = 8; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } }