flash-attn sycl: apply fixes and remove old implementation

Co-authored-by: safranowith <bsh155762@gmail.com>
Co-authored-by: ye-NX <y8703470@gmail.com>
This commit is contained in:
yehudit-dev 2025-11-23 16:59:25 +02:00
parent c9429b72d1
commit c62b98b083
8 changed files with 374 additions and 211 deletions

View File

@ -28,12 +28,12 @@ file(GLOB GGML_SOURCES_SYCL "*.cpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
# Include flash-attn sources (SYCL optimized flash attention implementation)
file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
file(GLOB GGML_HEADERS_SYCL_FLASH "fattn*.h" "fattn*.hpp")
file(GLOB GGML_SOURCES_SYCL_FLASH "fattn*.cpp" "fattn*.c")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
# Also include kernel headers under flash-attn/kernels
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "fattn_kernel*.h" "fattn_kernel*.hpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
if (WIN32)

View File

@ -38,6 +38,7 @@
#include "tsembd.hpp"
#include "wkv.hpp"
#include "pad_reflect_1d.hpp"
#include "fattn.hpp"
#endif // GGML_SYCL_BACKEND_HPP

View File

@ -0,0 +1,212 @@
#include "./fattn.hpp"
#include "./fattn_kernel.hpp"
#include "./fattn_common.hpp"
#include <cmath>
#include <cstring>
#include <limits>
#include <sycl/sycl.hpp>
#define Br 32
#define Bc 32
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
if (Q == nullptr || K == nullptr || V == nullptr) {
return false;
}
if (Q->type == GGML_TYPE_F32 && K->type == GGML_TYPE_F32 && V->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return true;
}
// if (Q->type == GGML_TYPE_F16 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
// return true;
// }
return false;
}
template<int64_t DQK, int64_t DV>
void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(Q != nullptr);
GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);
GGML_ASSERT(dst != nullptr);
//not support KV_Cache yet
GGML_ASSERT(K->ne[1] == V->ne[1]);
//not support multi head and gqa yet
GGML_ASSERT(Q->ne[2] == 1);
GGML_ASSERT(K->ne[2] == 1);
GGML_ASSERT(V->ne[2] == 1);
const float * Q_d = (const float *) Q->data;
const float * K_d = (const float *) K->data;
const float * V_d = (const float *) V->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int64_t N = Q->ne[1];
const ptrdiff_t q_row_stride = Q->nb[1] / (ptrdiff_t)sizeof(float);
const ptrdiff_t k_row_stride = K->nb[1] / (ptrdiff_t)sizeof(float);
const ptrdiff_t v_row_stride = V->nb[1] / (ptrdiff_t)sizeof(float);
const ptrdiff_t o_row_stride = dst->nb[1] / (ptrdiff_t)sizeof(float);
// const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
// const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
const int Tr = (N + Br - 1) / Br;
const int Tc = (N + Bc - 1) / Bc;
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
sycl::range<2> global(Br * Tr, Tc);
sycl::range<2> local(Br,1);
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 2> Qtile({Br, DQK}, cgh);
sycl::local_accessor<float, 2> Ktile({Bc, DQK}, cgh);
sycl::local_accessor<float, 2> Vtile({Bc, DV}, cgh);
sycl::local_accessor<float, 2> Stile({Br, Bc}, cgh);
sycl::local_accessor<float, 1> Ptile({Br * Bc}, cgh);
sycl::local_accessor<float, 1> m_local({Br}, cgh);
sycl::local_accessor<float, 1> l_local({Br}, cgh);
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
auto group = it.get_group();
int group_id_i = group.get_group_id(0);
int group_id_j = group.get_group_id(1);
int row0 = group_id_i * Br;
int col0 = group_id_j * Bc;
if (row0 >= (int) N || col0 >= (int) N) {
return;
}
const float* Q_block = Q_d + (ptrdiff_t)row0 * q_row_stride;
const float* K_block = K_d + (ptrdiff_t)col0 * k_row_stride;
const float* V_block = V_d + (ptrdiff_t)col0 * v_row_stride;
float* O_block = dst_d + (ptrdiff_t)row0 * o_row_stride;
//this lines does not support non-contiguous tensors
ggml_sycl_memcpy<Br * DQK>(q_loc, Q_block);
ggml_sycl_memcpy<Bc * DQK>(k_loc, K_block);
ggml_sycl_memcpy<Bc * DV>(v_loc, V_block);
it.barrier(sycl::access::fence_space::local_space);
flash_attn_mul_mat_QK_kernel<DQK>(
it,
Q_block, q_row_stride,
K_block, k_row_stride,
s_loc, (ptrdiff_t)Bc,
Br, Bc
);
it.barrier(sycl::access::fence_space::local_space);
flash_attn_softmax_kernel(
it,
s_loc, p_loc,
m_loc, l_loc,
Br, Bc,
l_d, m_d
);
it.barrier(sycl::access::fence_space::local_space);
flash_attn_mul_mat_PV_kernel<DV>(
it,
p_loc, (ptrdiff_t)Bc,
V_block, v_row_stride,
O_block, o_row_stride,
Br,Bc
);
it.barrier(sycl::access::fence_space::local_space);
});
});
stream->submit([&](sycl::handler& cgh) {
const ptrdiff_t o_stride = o_row_stride;
cgh.parallel_for(sycl::range<1>(N), [=](sycl::id<1> id_row) {
int row = id_row[0];
float l_val = l_d[row];
if (l_val <= 0.0f) {
return;
}
float inv_l = 1.0f / l_val;
float * o_row = dst_d + (ptrdiff_t)row * o_stride;
for (int col = 0; col < DV; ++col) {
o_row[col] *= inv_l;
}
});
});
sycl::free(l_d, *stream);
sycl::free(m_d, *stream);
}
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * V = dst->src[2];
switch (Q->ne[0]) {
case 64:
GGML_ASSERT(V->ne[0] == 64);
ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
break;
case 80:
GGML_ASSERT(V->ne[0] == 80);
ggml_sycl_op_flash_attn_2< 80, 80>(ctx, dst);
break;
case 96:
GGML_ASSERT(V->ne[0] == 96);
ggml_sycl_op_flash_attn_2< 96, 96>(ctx, dst);
break;
case 112:
GGML_ASSERT(V->ne[0] == 112);
ggml_sycl_op_flash_attn_2<112, 112>(ctx, dst);
break;
case 128:
GGML_ASSERT(V->ne[0] == 128);
ggml_sycl_op_flash_attn_2<128, 128>(ctx, dst);
break;
case 256:
GGML_ASSERT(V->ne[0] == 256);
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#ifndef GGML_SYCL_FATTN_HPP
#define GGML_SYCL_FATTN_HPP
#include "../common.hpp"
#include "common.hpp"
// Flash attention operation for SYCL backend
// This implements the Flash Attention algorithm optimized for SYCL devices
@ -8,3 +9,5 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
// Check if flash attention is supported for given tensor
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
#endif // GGML_SYCL_FATTN_HPP

View File

@ -0,0 +1,12 @@
#ifndef GGML_SYCL_FATTN_COMMON_HPP
#define GGML_SYCL_FATTN_COMMON_HPP
template<int N>
inline void ggml_sycl_memcpy(float* dst, const float* src) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
#endif // GGML_SYCL_FATTN_COMMON_HPP

View File

@ -0,0 +1,141 @@
#ifndef GGML_SYCL_FATTN_KERNEL_HPP
#define GGML_SYCL_FATTN_KERNEL_HPP
#include <sycl/sycl.hpp>
template <int64_t QKD>
inline void flash_attn_mul_mat_QK_kernel(
sycl::nd_item<2> it,
const float * Q, ptrdiff_t q_row_stride,
const float * K, ptrdiff_t k_row_stride,
float * S, ptrdiff_t s_row_stride,
const int Br, const int Bc) {
const int i = it.get_local_id(0);
if (i >= Br) {
return;
}
const float * q_vec = Q + i * q_row_stride;
float * s_row = S + i * s_row_stride;
for (int j = 0; j < Bc; ++j) {
const float * k_vec = K + j * k_row_stride;
float score = 0.0f;
#pragma unroll
for (int k = 0; k < QKD; ++k) {
score += q_vec[k] * k_vec[k];
}
s_row[j] = score;
}
}
inline void flash_attn_softmax_kernel(
sycl::nd_item<2> it,
float * S, float * P,
float * m_local, float * l_local,
const int Br, const int Bc,
float * l_d, float * m_d
) {
const int li = it.get_local_id(0);
const int gi = it.get_group(0);
const int row = gi * Br + li;
if (li >= Br) {
return;
}
const int row_offset = li * Bc;
float m_old = m_d[row];
float l_old = l_d[row];
// 2. Block max
float m_block = -INFINITY;
for (int j = 0; j < Bc; ++j) {
const float s_ij = S[row_offset + j];
m_block = sycl::fmax(m_block, s_ij);
}
// 3. Block exp-sum
float l_block = 0.0f;
for (int j = 0; j < Bc; ++j) {
const float e = sycl::exp(S[row_offset + j] - m_block);
P[row_offset + j] = e; // temporary store
l_block += e;
}
// 4. Merge block stats with global (streaming softmax)
float m_new;
float l_new;
if (l_old == 0.0f && m_old == -INFINITY) {
// first block for this row
m_new = m_block;
l_new = l_block;
} else {
m_new = sycl::fmax(m_old, m_block);
const float alpha = sycl::exp(m_old - m_new);
const float beta = sycl::exp(m_block - m_new);
l_new = alpha * l_old + beta * l_block;
}
// 5. Store updated global stats
m_d[row] = m_new;
l_d[row] = l_new;
// 6. Convert local e_ij to global probabilities p_ij
float scale_block = 0.0f;
if (l_new > 0.0f) {
scale_block = sycl::exp(m_block - m_new) / l_new;
}
for (int j = 0; j < Bc; ++j) {
P[row_offset + j] *= scale_block;
}
// 7. Optional: keep local copies
m_local[li] = m_new;
l_local[li] = l_new;
}
template <int64_t VD>
inline void flash_attn_mul_mat_PV_kernel(
sycl::nd_item<2> it,
const float * P, ptrdiff_t p_row_stride,
const float * V, ptrdiff_t v_row_stride,
float * O, ptrdiff_t o_row_stride,
const int Br, const int Bc) {
const int i = it.get_local_id(0);
if (i >= Br) {
return;
}
const float * p_row = P + i * p_row_stride;
float * o_row = O + i * o_row_stride;
for (int j = 0; j < VD; ++j) {
float acc = 0.0f;
#pragma unroll
for (int k = 0; k < Bc; ++k) {
const float * v_row = V + k * v_row_stride;
acc += p_row[k] * v_row[j];
}
o_row[j] = acc;
}
}
#endif // GGML_SYCL_FATTN_KERNEL_HPP

View File

@ -1,98 +0,0 @@
#include "flash-attn-sycl.h"
#include "kernels/flash-attn-kernel.h"
#include <cmath>
#include <cstring>
#include <limits>
#include <sycl/sycl.hpp>
#define FLASH_ATTN_BR_MAX 32
#define FLASH_ATTN_BC_MAX 32
// Flash Attention: https://arxiv.org/abs/2205.14135
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
GGML_ASSERT(Q != nullptr);
GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);
GGML_ASSERT(dst != nullptr);
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
return;
}
const float * Q_d = (const float *) Q->data;
const float * K_d = (const float *) K->data;
const float * V_d = (const float *) V->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int64_t d = Q->ne[0];
const int64_t N = Q->ne[1];
float scale;
float max_bias;
float logit_softcap;
std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
const bool masked = (mask != nullptr);
const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
const int Tr = (N + Br - 1) / Br;
const int Tc = (N + Bc - 1) / Bc;
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
stream->fill(l_d, 0.0f, N);
stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
stream->fill(dst_d, 0.0f, N * d);
stream->wait();
for (int j = 0; j < Tc; ++j) {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
const int i = idx[0];
flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
Bc, N, d, masked, scale);
});
});
}
stream->wait();
sycl::free(l_d, *stream);
sycl::free(m_d, *stream);
}
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
if (Q == nullptr || K == nullptr || V == nullptr) {
return false;
}
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
return false;
}
if (dst->type != GGML_TYPE_F32) {
return false;
}
return true;
}

View File

@ -1,108 +0,0 @@
#pragma once
#include <sycl/sycl.hpp>
template <int Br_MAX = 32, int Bc_MAX = 32>
inline void flash_attn_tiled_kernel(const float * Q,
const float * K,
const float * V,
float * O,
float * l,
float * m,
const int i_block,
const int j_block,
const int Br,
const int Bc,
const int N,
const int d,
const bool masked,
const float scale) {
const int i_start = i_block * Br;
const int j_start = j_block * Bc;
float S[Br_MAX][Bc_MAX];
float P[Br_MAX][Bc_MAX];
float m_local[Br_MAX];
float l_local[Br_MAX];
for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}
for (int kj = 0; kj < Bc; ++kj) {
const int k_row = j_start + kj;
if (k_row >= N) {
S[qi][kj] = -INFINITY;
continue;
}
if (masked && k_row > q_row) {
S[qi][kj] = -INFINITY;
continue;
}
float score = 0.0f;
for (int k = 0; k < d; ++k) {
score += Q[q_row * d + k] * K[k_row * d + k];
}
S[qi][kj] = score * scale;
}
}
for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}
m_local[qi] = -INFINITY;
for (int kj = 0; kj < Bc; ++kj) {
if (j_start + kj < N) {
m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
}
}
l_local[qi] = 0.0f;
for (int kj = 0; kj < Bc; ++kj) {
if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
l_local[qi] += P[qi][kj];
} else {
P[qi][kj] = 0.0f;
}
}
}
for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}
const float m_old = m[q_row];
const float m_new = sycl::fmax(m_old, m_local[qi]);
const float l_old = l[q_row];
const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];
const float correction_old = sycl::exp(m_old - m_new);
const float correction_new = sycl::exp(m_local[qi] - m_new);
for (int k = 0; k < d; ++k) {
float pv = 0.0f;
for (int kj = 0; kj < Bc; ++kj) {
const int v_row = j_start + kj;
if (v_row < N) {
pv += P[qi][kj] * V[v_row * d + k];
}
}
const int o_idx = q_row * d + k;
O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new;
}
l[q_row] = l_new;
m[q_row] = m_new;
}
}