sycl: initialize flash-attention implementation
Co-authored-by: safranowith <bsh155762@gmail.com> Co-authored-by: ye-NX <y8703470@gmail.com>
This commit is contained in:
parent
6de8ed7519
commit
c9429b72d1
|
|
@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
|
||||||
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
||||||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
||||||
|
|
||||||
|
# Include flash-attn sources (SYCL optimized flash attention implementation)
|
||||||
|
file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
|
||||||
|
file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
|
||||||
|
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
|
||||||
|
|
||||||
|
# Also include kernel headers under flash-attn/kernels
|
||||||
|
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
|
||||||
|
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
|
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
|
||||||
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
|
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
#include "flash-attn-sycl.h"
|
||||||
|
|
||||||
|
#include "kernels/flash-attn-kernel.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <limits>
|
||||||
|
#include <sycl/sycl.hpp>
|
||||||
|
|
||||||
|
#define FLASH_ATTN_BR_MAX 32
|
||||||
|
#define FLASH_ATTN_BC_MAX 32
|
||||||
|
|
||||||
|
// Flash Attention: https://arxiv.org/abs/2205.14135
|
||||||
|
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
|
GGML_ASSERT(Q != nullptr);
|
||||||
|
GGML_ASSERT(K != nullptr);
|
||||||
|
GGML_ASSERT(V != nullptr);
|
||||||
|
GGML_ASSERT(dst != nullptr);
|
||||||
|
|
||||||
|
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||||
|
fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * Q_d = (const float *) Q->data;
|
||||||
|
const float * K_d = (const float *) K->data;
|
||||||
|
const float * V_d = (const float *) V->data;
|
||||||
|
float * dst_d = (float *) dst->data;
|
||||||
|
|
||||||
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
const int64_t d = Q->ne[0];
|
||||||
|
const int64_t N = Q->ne[1];
|
||||||
|
|
||||||
|
float scale;
|
||||||
|
float max_bias;
|
||||||
|
float logit_softcap;
|
||||||
|
|
||||||
|
std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||||
|
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||||
|
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
|
const bool masked = (mask != nullptr);
|
||||||
|
|
||||||
|
const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
|
||||||
|
const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
|
||||||
|
|
||||||
|
const int Tr = (N + Br - 1) / Br;
|
||||||
|
const int Tc = (N + Bc - 1) / Bc;
|
||||||
|
|
||||||
|
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||||
|
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||||
|
|
||||||
|
stream->fill(l_d, 0.0f, N);
|
||||||
|
stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
|
||||||
|
stream->fill(dst_d, 0.0f, N * d);
|
||||||
|
stream->wait();
|
||||||
|
|
||||||
|
for (int j = 0; j < Tc; ++j) {
|
||||||
|
stream->submit([&](sycl::handler & cgh) {
|
||||||
|
cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
|
||||||
|
const int i = idx[0];
|
||||||
|
flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
|
||||||
|
Bc, N, d, masked, scale);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
stream->wait();
|
||||||
|
|
||||||
|
sycl::free(l_d, *stream);
|
||||||
|
sycl::free(m_d, *stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
if (Q == nullptr || K == nullptr || V == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst->type != GGML_TYPE_F32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "../common.hpp"
|
||||||
|
|
||||||
|
// Flash attention operation for SYCL backend
|
||||||
|
// This implements the Flash Attention algorithm optimized for SYCL devices
|
||||||
|
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
// Check if flash attention is supported for given tensor
|
||||||
|
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sycl/sycl.hpp>
|
||||||
|
|
||||||
|
template <int Br_MAX = 32, int Bc_MAX = 32>
|
||||||
|
inline void flash_attn_tiled_kernel(const float * Q,
|
||||||
|
const float * K,
|
||||||
|
const float * V,
|
||||||
|
float * O,
|
||||||
|
float * l,
|
||||||
|
float * m,
|
||||||
|
const int i_block,
|
||||||
|
const int j_block,
|
||||||
|
const int Br,
|
||||||
|
const int Bc,
|
||||||
|
const int N,
|
||||||
|
const int d,
|
||||||
|
const bool masked,
|
||||||
|
const float scale) {
|
||||||
|
const int i_start = i_block * Br;
|
||||||
|
const int j_start = j_block * Bc;
|
||||||
|
|
||||||
|
float S[Br_MAX][Bc_MAX];
|
||||||
|
float P[Br_MAX][Bc_MAX];
|
||||||
|
float m_local[Br_MAX];
|
||||||
|
float l_local[Br_MAX];
|
||||||
|
|
||||||
|
for (int qi = 0; qi < Br; ++qi) {
|
||||||
|
const int q_row = i_start + qi;
|
||||||
|
if (q_row >= N) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int kj = 0; kj < Bc; ++kj) {
|
||||||
|
const int k_row = j_start + kj;
|
||||||
|
if (k_row >= N) {
|
||||||
|
S[qi][kj] = -INFINITY;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (masked && k_row > q_row) {
|
||||||
|
S[qi][kj] = -INFINITY;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
float score = 0.0f;
|
||||||
|
for (int k = 0; k < d; ++k) {
|
||||||
|
score += Q[q_row * d + k] * K[k_row * d + k];
|
||||||
|
}
|
||||||
|
S[qi][kj] = score * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int qi = 0; qi < Br; ++qi) {
|
||||||
|
const int q_row = i_start + qi;
|
||||||
|
if (q_row >= N) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_local[qi] = -INFINITY;
|
||||||
|
for (int kj = 0; kj < Bc; ++kj) {
|
||||||
|
if (j_start + kj < N) {
|
||||||
|
m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l_local[qi] = 0.0f;
|
||||||
|
for (int kj = 0; kj < Bc; ++kj) {
|
||||||
|
if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
|
||||||
|
P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
|
||||||
|
l_local[qi] += P[qi][kj];
|
||||||
|
} else {
|
||||||
|
P[qi][kj] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int qi = 0; qi < Br; ++qi) {
|
||||||
|
const int q_row = i_start + qi;
|
||||||
|
if (q_row >= N) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float m_old = m[q_row];
|
||||||
|
const float m_new = sycl::fmax(m_old, m_local[qi]);
|
||||||
|
const float l_old = l[q_row];
|
||||||
|
const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];
|
||||||
|
|
||||||
|
const float correction_old = sycl::exp(m_old - m_new);
|
||||||
|
const float correction_new = sycl::exp(m_local[qi] - m_new);
|
||||||
|
|
||||||
|
for (int k = 0; k < d; ++k) {
|
||||||
|
float pv = 0.0f;
|
||||||
|
for (int kj = 0; kj < Bc; ++kj) {
|
||||||
|
const int v_row = j_start + kj;
|
||||||
|
if (v_row < N) {
|
||||||
|
pv += P[qi][kj] * V[v_row * d + k];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int o_idx = q_row * d + k;
|
||||||
|
O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
l[q_row] = l_new;
|
||||||
|
m[q_row] = m_new;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -3839,6 +3839,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
ggml_sycl_argsort(ctx, dst);
|
ggml_sycl_argsort(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
ggml_sycl_op_flash_attn(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
ggml_sycl_op_timestep_embedding(ctx, dst);
|
ggml_sycl_op_timestep_embedding(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
|
@ -4501,6 +4504,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
return ggml_sycl_flash_attn_ext_supported(op);
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
return true;
|
return true;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue