fix: add supported tests and fix related issues
Co-authored-by: safranowith <bsh155762@gmail.com> Co-authored-by: ye-NX <y8703470@gmail.com>
This commit is contained in:
parent
c62b98b083
commit
e1511c3be3
|
|
@ -83,15 +83,16 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|||
sycl::local_accessor<float, 1> m_local({Br}, cgh);
|
||||
sycl::local_accessor<float, 1> l_local({Br}, cgh);
|
||||
|
||||
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
|
||||
|
||||
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
|
||||
auto group = it.get_group();
|
||||
int group_id_i = group.get_group_id(0);
|
||||
int group_id_j = group.get_group_id(1);
|
||||
|
|
@ -174,7 +175,6 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|||
sycl::free(m_d, *stream);
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
|
@ -204,8 +204,12 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576:
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_sycl_op_flash_attn_2<576, 512>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("Unsupported head size");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,14 +53,14 @@ inline void flash_attn_softmax_kernel(
|
|||
float m_old = m_d[row];
|
||||
float l_old = l_d[row];
|
||||
|
||||
// 2. Block max
|
||||
// Block max
|
||||
float m_block = -INFINITY;
|
||||
for (int j = 0; j < Bc; ++j) {
|
||||
const float s_ij = S[row_offset + j];
|
||||
m_block = sycl::fmax(m_block, s_ij);
|
||||
}
|
||||
|
||||
// 3. Block exp-sum
|
||||
// Block exp-sum
|
||||
float l_block = 0.0f;
|
||||
for (int j = 0; j < Bc; ++j) {
|
||||
const float e = sycl::exp(S[row_offset + j] - m_block);
|
||||
|
|
@ -68,7 +68,7 @@ inline void flash_attn_softmax_kernel(
|
|||
l_block += e;
|
||||
}
|
||||
|
||||
// 4. Merge block stats with global (streaming softmax)
|
||||
// Merge block stats with global (streaming softmax)
|
||||
float m_new;
|
||||
float l_new;
|
||||
|
||||
|
|
@ -85,11 +85,11 @@ inline void flash_attn_softmax_kernel(
|
|||
l_new = alpha * l_old + beta * l_block;
|
||||
}
|
||||
|
||||
// 5. Store updated global stats
|
||||
// Store updated global stats
|
||||
m_d[row] = m_new;
|
||||
l_d[row] = l_new;
|
||||
|
||||
// 6. Convert local e_ij to global probabilities p_ij
|
||||
// Convert local e_ij to global probabilities p_ij
|
||||
float scale_block = 0.0f;
|
||||
if (l_new > 0.0f) {
|
||||
scale_block = sycl::exp(m_block - m_new) / l_new;
|
||||
|
|
@ -99,7 +99,7 @@ inline void flash_attn_softmax_kernel(
|
|||
P[row_offset + j] *= scale_block;
|
||||
}
|
||||
|
||||
// 7. Optional: keep local copies
|
||||
// Optional: keep local copies
|
||||
m_local[li] = m_new;
|
||||
l_local[li] = l_new;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7117,6 +7117,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
for (int kv : { 4096, 8192, 16384, }) {
|
||||
for (int hs : { 64, 128, }) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 1, {1, 1}, kv, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F32));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue