fix: add supported tests and fix related issues
Co-authored-by: safranowith <bsh155762@gmail.com> Co-authored-by: ye-NX <y8703470@gmail.com>
This commit is contained in:
parent
c62b98b083
commit
e1511c3be3
|
|
@ -83,15 +83,16 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||||
sycl::local_accessor<float, 1> m_local({Br}, cgh);
|
sycl::local_accessor<float, 1> m_local({Br}, cgh);
|
||||||
sycl::local_accessor<float, 1> l_local({Br}, cgh);
|
sycl::local_accessor<float, 1> l_local({Br}, cgh);
|
||||||
|
|
||||||
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
|
||||||
|
|
||||||
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
|
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
|
||||||
|
|
||||||
|
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
|
||||||
auto group = it.get_group();
|
auto group = it.get_group();
|
||||||
int group_id_i = group.get_group_id(0);
|
int group_id_i = group.get_group_id(0);
|
||||||
int group_id_j = group.get_group_id(1);
|
int group_id_j = group.get_group_id(1);
|
||||||
|
|
@ -174,7 +175,6 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||||
sycl::free(m_d, *stream);
|
sycl::free(m_d, *stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
@ -204,8 +204,12 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
GGML_ASSERT(V->ne[0] == 256);
|
GGML_ASSERT(V->ne[0] == 256);
|
||||||
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
|
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case 576:
|
||||||
|
GGML_ASSERT(V->ne[0] == 512);
|
||||||
|
ggml_sycl_op_flash_attn_2<576, 512>(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("Unsupported head size");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,14 +53,14 @@ inline void flash_attn_softmax_kernel(
|
||||||
float m_old = m_d[row];
|
float m_old = m_d[row];
|
||||||
float l_old = l_d[row];
|
float l_old = l_d[row];
|
||||||
|
|
||||||
// 2. Block max
|
// Block max
|
||||||
float m_block = -INFINITY;
|
float m_block = -INFINITY;
|
||||||
for (int j = 0; j < Bc; ++j) {
|
for (int j = 0; j < Bc; ++j) {
|
||||||
const float s_ij = S[row_offset + j];
|
const float s_ij = S[row_offset + j];
|
||||||
m_block = sycl::fmax(m_block, s_ij);
|
m_block = sycl::fmax(m_block, s_ij);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Block exp-sum
|
// Block exp-sum
|
||||||
float l_block = 0.0f;
|
float l_block = 0.0f;
|
||||||
for (int j = 0; j < Bc; ++j) {
|
for (int j = 0; j < Bc; ++j) {
|
||||||
const float e = sycl::exp(S[row_offset + j] - m_block);
|
const float e = sycl::exp(S[row_offset + j] - m_block);
|
||||||
|
|
@ -68,7 +68,7 @@ inline void flash_attn_softmax_kernel(
|
||||||
l_block += e;
|
l_block += e;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Merge block stats with global (streaming softmax)
|
// Merge block stats with global (streaming softmax)
|
||||||
float m_new;
|
float m_new;
|
||||||
float l_new;
|
float l_new;
|
||||||
|
|
||||||
|
|
@ -85,11 +85,11 @@ inline void flash_attn_softmax_kernel(
|
||||||
l_new = alpha * l_old + beta * l_block;
|
l_new = alpha * l_old + beta * l_block;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Store updated global stats
|
// Store updated global stats
|
||||||
m_d[row] = m_new;
|
m_d[row] = m_new;
|
||||||
l_d[row] = l_new;
|
l_d[row] = l_new;
|
||||||
|
|
||||||
// 6. Convert local e_ij to global probabilities p_ij
|
// Convert local e_ij to global probabilities p_ij
|
||||||
float scale_block = 0.0f;
|
float scale_block = 0.0f;
|
||||||
if (l_new > 0.0f) {
|
if (l_new > 0.0f) {
|
||||||
scale_block = sycl::exp(m_block - m_new) / l_new;
|
scale_block = sycl::exp(m_block - m_new) / l_new;
|
||||||
|
|
@ -99,7 +99,7 @@ inline void flash_attn_softmax_kernel(
|
||||||
P[row_offset + j] *= scale_block;
|
P[row_offset + j] *= scale_block;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Optional: keep local copies
|
// Optional: keep local copies
|
||||||
m_local[li] = m_new;
|
m_local[li] = m_new;
|
||||||
l_local[li] = l_new;
|
l_local[li] = l_new;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7117,6 +7117,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int kv : { 4096, 8192, 16384, }) {
|
||||||
|
for (int hs : { 64, 128, }) {
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 1, {1, 1}, kv, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F32));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
||||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue