This commit is contained in:
YehuditE 2025-12-17 08:57:16 +03:00 committed by GitHub
commit 529353ce83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 425 additions and 0 deletions

View File

@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
file(GLOB GGML_SOURCES_SYCL "*.cpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
# Include flash-attn sources (SYCL optimized flash attention implementation)
file(GLOB GGML_HEADERS_SYCL_FLASH "fattn*.h" "fattn*.hpp")
file(GLOB GGML_SOURCES_SYCL_FLASH "fattn*.cpp" "fattn*.c")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
# Also include kernel headers under flash-attn/kernels
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "fattn_kernel*.h" "fattn_kernel*.hpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
if (WIN32)
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))

View File

@ -40,6 +40,7 @@
#include "tsembd.hpp"
#include "wkv.hpp"
#include "pad_reflect_1d.hpp"
#include "fattn.hpp"
#endif // GGML_SYCL_BACKEND_HPP

View File

@ -0,0 +1,236 @@
#include "./fattn.hpp"
#include "./fattn_kernel.hpp"
#include "./fattn_common.hpp"
#include <cmath>
#include <cstring>
#include <limits>
#include <sycl/sycl.hpp>
#define Br 32
#define Bc 32
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
float scale, max_bias, logit_softcap;
std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
if( max_bias != 0.0f || logit_softcap != 0.0f){
return false;
}
if (Q == nullptr || K == nullptr || V == nullptr) {
return false;
}
if (mask != 0) {
return false;
}
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
return false;
}
int64_t DQK = Q->ne[0];
int64_t DV = V->ne[0];
if (DQK != DV){
return false;
}
if (DV != 32 && DV != 64 && DV != 80 && DV != 96 && DV != 112 && DV != 128 && DV != 256 && DV != 512){
return false;
}
//not support multi-head yet
if (Q->ne[2] != 1 || K->ne[2] != 1 || V->ne[2] != 1) {
return false;
}
return true;
}
template<int64_t DQK, int64_t DV>
void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const float * Q_d = (const float *) Q->data;
const float * K_d = (const float *) K->data;
const float * V_d = (const float *) V->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int64_t N = Q->ne[1];
const ptrdiff_t q_row_stride = Q->nb[1] / (ptrdiff_t)sizeof(float);
const ptrdiff_t k_row_stride = K->nb[1] / (ptrdiff_t)sizeof(float);
const ptrdiff_t v_row_stride = V->nb[1] / (ptrdiff_t)sizeof(float);
const ptrdiff_t o_row_stride = dst->nb[1] / (ptrdiff_t)sizeof(float);
// const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
// const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
const int Tr = (N + Br - 1) / Br;
const int Tc = (N + Bc - 1) / Bc;
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
sycl::range<2> global(Br * Tr, Tc);
sycl::range<2> local(Br,1);
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 2> Qtile({Br, DQK}, cgh);
sycl::local_accessor<float, 2> Ktile({Bc, DQK}, cgh);
sycl::local_accessor<float, 2> Vtile({Bc, DV}, cgh);
sycl::local_accessor<float, 2> Stile({Br, Bc}, cgh);
sycl::local_accessor<float, 1> Ptile({Br * Bc}, cgh);
sycl::local_accessor<float, 1> m_local({Br}, cgh);
sycl::local_accessor<float, 1> l_local({Br}, cgh);
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
auto group = it.get_group();
int group_id_i = group.get_group_id(0);
int group_id_j = group.get_group_id(1);
int row0 = group_id_i * Br;
int col0 = group_id_j * Bc;
if (row0 >= (int) N || col0 >= (int) N) {
return;
}
const float* Q_block = Q_d + (ptrdiff_t)row0 * q_row_stride;
const float* K_block = K_d + (ptrdiff_t)col0 * k_row_stride;
const float* V_block = V_d + (ptrdiff_t)col0 * v_row_stride;
float* O_block = dst_d + (ptrdiff_t)row0 * o_row_stride;
//this lines does not support non-contiguous tensors
ggml_sycl_memcpy<Br * DQK>(q_loc, Q_block);
ggml_sycl_memcpy<Bc * DQK>(k_loc, K_block);
ggml_sycl_memcpy<Bc * DV>(v_loc, V_block);
it.barrier(sycl::access::fence_space::local_space);
flash_attn_mul_mat_QK_kernel<DQK>(
it,
Q_block, q_row_stride,
K_block, k_row_stride,
s_loc, (ptrdiff_t)Bc,
Br, Bc
);
it.barrier(sycl::access::fence_space::local_space);
flash_attn_softmax_kernel(
it,
s_loc, p_loc,
m_loc, l_loc,
Br, Bc,
l_d, m_d
);
it.barrier(sycl::access::fence_space::local_space);
flash_attn_mul_mat_PV_kernel<DV>(
it,
p_loc, (ptrdiff_t)Bc,
V_block, v_row_stride,
O_block, o_row_stride,
Br,Bc
);
it.barrier(sycl::access::fence_space::local_space);
});
});
stream->submit([&](sycl::handler& cgh) {
const ptrdiff_t o_stride = o_row_stride;
cgh.parallel_for(sycl::range<1>(N), [=](sycl::id<1> id_row) {
int row = id_row[0];
float l_val = l_d[row];
if (l_val <= 0.0f) {
return;
}
float inv_l = 1.0f / l_val;
float * o_row = dst_d + (ptrdiff_t)row * o_stride;
for (int col = 0; col < DV; ++col) {
o_row[col] *= inv_l;
}
});
});
sycl::free(l_d, *stream);
sycl::free(m_d, *stream);
}
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * V = dst->src[2];
switch (Q->ne[0]) {
case 32:
GGML_ASSERT(V->ne[0] == 32);
ggml_sycl_op_flash_attn_2< 32, 32>(ctx, dst);
break;
case 64:
GGML_ASSERT(V->ne[0] == 64);
ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
break;
case 80:
GGML_ASSERT(V->ne[0] == 80);
ggml_sycl_op_flash_attn_2< 80, 80>(ctx, dst);
break;
case 96:
GGML_ASSERT(V->ne[0] == 96);
ggml_sycl_op_flash_attn_2< 96, 96>(ctx, dst);
break;
case 112:
GGML_ASSERT(V->ne[0] == 112);
ggml_sycl_op_flash_attn_2<112, 112>(ctx, dst);
break;
case 128:
GGML_ASSERT(V->ne[0] == 128);
ggml_sycl_op_flash_attn_2<128, 128>(ctx, dst);
break;
case 256:
GGML_ASSERT(V->ne[0] == 256);
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
break;
case 576:
GGML_ASSERT(V->ne[0] == 512);
ggml_sycl_op_flash_attn_2<512, 512>(ctx, dst);
break;
default:
fprintf(stderr, "Warning: Unsupported head size %ld — skipping op\n", Q->ne[0]);
break;
}
}

View File

@ -0,0 +1,13 @@
#ifndef GGML_SYCL_FATTN_HPP
#define GGML_SYCL_FATTN_HPP
#include "common.hpp"
// Flash attention operation for SYCL backend
// This implements the Flash Attention algorithm optimized for SYCL devices
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
// Check if flash attention is supported for given tensor
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
#endif // GGML_SYCL_FATTN_HPP

View File

@ -0,0 +1,12 @@
#ifndef GGML_SYCL_FATTN_COMMON_HPP
#define GGML_SYCL_FATTN_COMMON_HPP
template<int N>
inline void ggml_sycl_memcpy(float* dst, const float* src) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
#endif // GGML_SYCL_FATTN_COMMON_HPP

View File

@ -0,0 +1,141 @@
#ifndef GGML_SYCL_FATTN_KERNEL_HPP
#define GGML_SYCL_FATTN_KERNEL_HPP
#include <sycl/sycl.hpp>
template <int64_t QKD>
inline void flash_attn_mul_mat_QK_kernel(
sycl::nd_item<2> it,
const float * Q, ptrdiff_t q_row_stride,
const float * K, ptrdiff_t k_row_stride,
float * S, ptrdiff_t s_row_stride,
const int Br, const int Bc) {
const int i = it.get_local_id(0);
if (i >= Br) {
return;
}
const float * q_vec = Q + i * q_row_stride;
float * s_row = S + i * s_row_stride;
for (int j = 0; j < Bc; ++j) {
const float * k_vec = K + j * k_row_stride;
float score = 0.0f;
#pragma unroll
for (int k = 0; k < QKD; ++k) {
score += q_vec[k] * k_vec[k];
}
s_row[j] = score;
}
}
inline void flash_attn_softmax_kernel(
sycl::nd_item<2> it,
float * S, float * P,
float * m_local, float * l_local,
const int Br, const int Bc,
float * l_d, float * m_d
) {
const int li = it.get_local_id(0);
const int gi = it.get_group(0);
const int row = gi * Br + li;
if (li >= Br) {
return;
}
const int row_offset = li * Bc;
float m_old = m_d[row];
float l_old = l_d[row];
// Block max
float m_block = -INFINITY;
for (int j = 0; j < Bc; ++j) {
const float s_ij = S[row_offset + j];
m_block = sycl::fmax(m_block, s_ij);
}
// Block exp-sum
float l_block = 0.0f;
for (int j = 0; j < Bc; ++j) {
const float e = sycl::exp(S[row_offset + j] - m_block);
P[row_offset + j] = e; // temporary store
l_block += e;
}
// Merge block stats with global (streaming softmax)
float m_new;
float l_new;
if (l_old == 0.0f && m_old == -INFINITY) {
// first block for this row
m_new = m_block;
l_new = l_block;
} else {
m_new = sycl::fmax(m_old, m_block);
const float alpha = sycl::exp(m_old - m_new);
const float beta = sycl::exp(m_block - m_new);
l_new = alpha * l_old + beta * l_block;
}
// Store updated global stats
m_d[row] = m_new;
l_d[row] = l_new;
// Convert local e_ij to global probabilities p_ij
float scale_block = 0.0f;
if (l_new > 0.0f) {
scale_block = sycl::exp(m_block - m_new) / l_new;
}
for (int j = 0; j < Bc; ++j) {
P[row_offset + j] *= scale_block;
}
// Optional: keep local copies
m_local[li] = m_new;
l_local[li] = l_new;
}
template <int64_t VD>
inline void flash_attn_mul_mat_PV_kernel(
sycl::nd_item<2> it,
const float * P, ptrdiff_t p_row_stride,
const float * V, ptrdiff_t v_row_stride,
float * O, ptrdiff_t o_row_stride,
const int Br, const int Bc) {
const int i = it.get_local_id(0);
if (i >= Br) {
return;
}
const float * p_row = P + i * p_row_stride;
float * o_row = O + i * o_row_stride;
for (int j = 0; j < VD; ++j) {
float acc = 0.0f;
#pragma unroll
for (int k = 0; k < Bc; ++k) {
const float * v_row = V + k * v_row_stride;
acc += p_row[k] * v_row[j];
}
o_row[j] = acc;
}
}
#endif // GGML_SYCL_FATTN_KERNEL_HPP

View File

@ -3927,6 +3927,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_sycl_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_sycl_op_flash_attn(ctx, dst);
break;
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_sycl_op_timestep_embedding(ctx, dst);
break;
@ -4617,6 +4620,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(op);
case GGML_OP_ARGSORT:
return op->src[0]->ne[0] * sizeof(int) <=
ggml_sycl_info().devices[device].smpbo;

View File

@ -8162,6 +8162,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 1, {1, 1}, kv, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F32));
}
}
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));