From e1511c3be3e2244ac130e72fed5237a673d6d099 Mon Sep 17 00:00:00 2001 From: yehudit-dev Date: Tue, 2 Dec 2025 11:27:44 +0200 Subject: [PATCH] fix: add supported tests and fix related issues Co-authored-by: safranowith Co-authored-by: ye-NX --- ggml/src/ggml-sycl/fattn.cpp | 24 ++++++++++++++---------- ggml/src/ggml-sycl/fattn_kernel.hpp | 12 ++++++------ tests/test-backend-ops.cpp | 8 ++++++++ 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp index 67cafa10cd..ea6730145f 100644 --- a/ggml/src/ggml-sycl/fattn.cpp +++ b/ggml/src/ggml-sycl/fattn.cpp @@ -83,15 +83,16 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds sycl::local_accessor m_local({Br}, cgh); sycl::local_accessor l_local({Br}, cgh); - float* q_loc = Qtile.template get_multi_ptr().get(); - float* k_loc = Ktile.template get_multi_ptr().get(); - float* v_loc = Vtile.template get_multi_ptr().get(); - float* s_loc = Stile.template get_multi_ptr().get(); - float* p_loc = Ptile.template get_multi_ptr().get(); - float* m_loc = m_local.template get_multi_ptr().get(); - float* l_loc = l_local.template get_multi_ptr().get(); - cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) { + + float* q_loc = Qtile.template get_multi_ptr().get(); + float* k_loc = Ktile.template get_multi_ptr().get(); + float* v_loc = Vtile.template get_multi_ptr().get(); + float* s_loc = Stile.template get_multi_ptr().get(); + float* p_loc = Ptile.template get_multi_ptr().get(); + float* m_loc = m_local.template get_multi_ptr().get(); + float* l_loc = l_local.template get_multi_ptr().get(); + auto group = it.get_group(); int group_id_i = group.get_group_id(0); int group_id_j = group.get_group_id(1); @@ -174,7 +175,6 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds sycl::free(m_d, *stream); } - void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * V = dst->src[2]; @@ -204,8 +204,12 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) GGML_ASSERT(V->ne[0] == 256); ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst); break; + case 576: + GGML_ASSERT(V->ne[0] == 512); + ggml_sycl_op_flash_attn_2<576, 512>(ctx, dst); + break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("Unsupported head size"); break; } } diff --git a/ggml/src/ggml-sycl/fattn_kernel.hpp b/ggml/src/ggml-sycl/fattn_kernel.hpp index 8229470567..5c49a0a4c1 100644 --- a/ggml/src/ggml-sycl/fattn_kernel.hpp +++ b/ggml/src/ggml-sycl/fattn_kernel.hpp @@ -53,14 +53,14 @@ inline void flash_attn_softmax_kernel( float m_old = m_d[row]; float l_old = l_d[row]; - // 2. Block max + // Block max float m_block = -INFINITY; for (int j = 0; j < Bc; ++j) { const float s_ij = S[row_offset + j]; m_block = sycl::fmax(m_block, s_ij); } - // 3. Block exp-sum + // Block exp-sum float l_block = 0.0f; for (int j = 0; j < Bc; ++j) { const float e = sycl::exp(S[row_offset + j] - m_block); @@ -68,7 +68,7 @@ inline void flash_attn_softmax_kernel( l_block += e; } - // 4. Merge block stats with global (streaming softmax) + // Merge block stats with global (streaming softmax) float m_new; float l_new; @@ -85,11 +85,11 @@ inline void flash_attn_softmax_kernel( l_new = alpha * l_old + beta * l_block; } - // 5. Store updated global stats + // Store updated global stats m_d[row] = m_new; l_d[row] = l_new; - // 6. Convert local e_ij to global probabilities p_ij + // Convert local e_ij to global probabilities p_ij float scale_block = 0.0f; if (l_new > 0.0f) { scale_block = sycl::exp(m_block - m_new) / l_new; @@ -99,7 +99,7 @@ inline void flash_attn_softmax_kernel( P[row_offset + j] *= scale_block; } - // 7. Optional: keep local copies + // Optional: keep local copies m_local[li] = m_new; l_local[li] = l_new; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index fa98db2982..28a21b426e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7117,6 +7117,14 @@ static std::vector> make_test_cases_perf() { } } + + for (int kv : { 4096, 8192, 16384, }) { + for (int hs : { 64, 128, }) { + test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 1, {1, 1}, kv, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F32)); + } + } + + test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));