flash-attn sycl: apply fixes and remove old implementation
Co-authored-by: safranowith <bsh155762@gmail.com> Co-authored-by: ye-NX <y8703470@gmail.com>
This commit is contained in:
parent
c9429b72d1
commit
c62b98b083
|
|
@ -28,12 +28,12 @@ file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
|||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
||||
|
||||
# Include flash-attn sources (SYCL optimized flash attention implementation)
|
||||
file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
|
||||
file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
|
||||
file(GLOB GGML_HEADERS_SYCL_FLASH "fattn*.h" "fattn*.hpp")
|
||||
file(GLOB GGML_SOURCES_SYCL_FLASH "fattn*.cpp" "fattn*.c")
|
||||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
|
||||
|
||||
# Also include kernel headers under flash-attn/kernels
|
||||
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
|
||||
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "fattn_kernel*.h" "fattn_kernel*.hpp")
|
||||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
|
||||
|
||||
if (WIN32)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@
|
|||
#include "tsembd.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "pad_reflect_1d.hpp"
|
||||
#include "fattn.hpp"
|
||||
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
|
|
|||
|
|
@ -0,0 +1,212 @@
|
|||
#include "./fattn.hpp"
|
||||
#include "./fattn_kernel.hpp"
|
||||
#include "./fattn_common.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#define Br 32
|
||||
#define Bc 32
|
||||
|
||||
|
||||
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
if (Q == nullptr || K == nullptr || V == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (Q->type == GGML_TYPE_F32 && K->type == GGML_TYPE_F32 && V->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
// if (Q->type == GGML_TYPE_F16 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
// return true;
|
||||
// }
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
template<int64_t DQK, int64_t DV>
|
||||
void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(Q != nullptr);
|
||||
GGML_ASSERT(K != nullptr);
|
||||
GGML_ASSERT(V != nullptr);
|
||||
GGML_ASSERT(dst != nullptr);
|
||||
|
||||
//not support KV_Cache yet
|
||||
GGML_ASSERT(K->ne[1] == V->ne[1]);
|
||||
|
||||
//not support multi head and gqa yet
|
||||
GGML_ASSERT(Q->ne[2] == 1);
|
||||
GGML_ASSERT(K->ne[2] == 1);
|
||||
GGML_ASSERT(V->ne[2] == 1);
|
||||
|
||||
const float * Q_d = (const float *) Q->data;
|
||||
const float * K_d = (const float *) K->data;
|
||||
const float * V_d = (const float *) V->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
const int64_t N = Q->ne[1];
|
||||
|
||||
const ptrdiff_t q_row_stride = Q->nb[1] / (ptrdiff_t)sizeof(float);
|
||||
const ptrdiff_t k_row_stride = K->nb[1] / (ptrdiff_t)sizeof(float);
|
||||
const ptrdiff_t v_row_stride = V->nb[1] / (ptrdiff_t)sizeof(float);
|
||||
const ptrdiff_t o_row_stride = dst->nb[1] / (ptrdiff_t)sizeof(float);
|
||||
|
||||
// const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
|
||||
// const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
|
||||
|
||||
const int Tr = (N + Br - 1) / Br;
|
||||
const int Tc = (N + Bc - 1) / Bc;
|
||||
|
||||
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||
|
||||
sycl::range<2> global(Br * Tr, Tc);
|
||||
sycl::range<2> local(Br,1);
|
||||
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 2> Qtile({Br, DQK}, cgh);
|
||||
sycl::local_accessor<float, 2> Ktile({Bc, DQK}, cgh);
|
||||
sycl::local_accessor<float, 2> Vtile({Bc, DV}, cgh);
|
||||
sycl::local_accessor<float, 2> Stile({Br, Bc}, cgh);
|
||||
sycl::local_accessor<float, 1> Ptile({Br * Bc}, cgh);
|
||||
sycl::local_accessor<float, 1> m_local({Br}, cgh);
|
||||
sycl::local_accessor<float, 1> l_local({Br}, cgh);
|
||||
|
||||
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
|
||||
auto group = it.get_group();
|
||||
int group_id_i = group.get_group_id(0);
|
||||
int group_id_j = group.get_group_id(1);
|
||||
|
||||
|
||||
int row0 = group_id_i * Br;
|
||||
int col0 = group_id_j * Bc;
|
||||
|
||||
if (row0 >= (int) N || col0 >= (int) N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float* Q_block = Q_d + (ptrdiff_t)row0 * q_row_stride;
|
||||
const float* K_block = K_d + (ptrdiff_t)col0 * k_row_stride;
|
||||
const float* V_block = V_d + (ptrdiff_t)col0 * v_row_stride;
|
||||
float* O_block = dst_d + (ptrdiff_t)row0 * o_row_stride;
|
||||
|
||||
//this lines does not support non-contiguous tensors
|
||||
ggml_sycl_memcpy<Br * DQK>(q_loc, Q_block);
|
||||
ggml_sycl_memcpy<Bc * DQK>(k_loc, K_block);
|
||||
ggml_sycl_memcpy<Bc * DV>(v_loc, V_block);
|
||||
|
||||
it.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
flash_attn_mul_mat_QK_kernel<DQK>(
|
||||
it,
|
||||
Q_block, q_row_stride,
|
||||
K_block, k_row_stride,
|
||||
s_loc, (ptrdiff_t)Bc,
|
||||
Br, Bc
|
||||
);
|
||||
|
||||
it.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
flash_attn_softmax_kernel(
|
||||
it,
|
||||
s_loc, p_loc,
|
||||
m_loc, l_loc,
|
||||
Br, Bc,
|
||||
l_d, m_d
|
||||
);
|
||||
|
||||
it.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
flash_attn_mul_mat_PV_kernel<DV>(
|
||||
it,
|
||||
p_loc, (ptrdiff_t)Bc,
|
||||
V_block, v_row_stride,
|
||||
O_block, o_row_stride,
|
||||
Br,Bc
|
||||
);
|
||||
|
||||
it.barrier(sycl::access::fence_space::local_space);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
const ptrdiff_t o_stride = o_row_stride;
|
||||
|
||||
cgh.parallel_for(sycl::range<1>(N), [=](sycl::id<1> id_row) {
|
||||
int row = id_row[0];
|
||||
float l_val = l_d[row];
|
||||
|
||||
if (l_val <= 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_l = 1.0f / l_val;
|
||||
float * o_row = dst_d + (ptrdiff_t)row * o_stride;
|
||||
|
||||
for (int col = 0; col < DV; ++col) {
|
||||
o_row[col] *= inv_l;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
sycl::free(l_d, *stream);
|
||||
sycl::free(m_d, *stream);
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
GGML_ASSERT(V->ne[0] == 64);
|
||||
ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
GGML_ASSERT(V->ne[0] == 80);
|
||||
ggml_sycl_op_flash_attn_2< 80, 80>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
GGML_ASSERT(V->ne[0] == 96);
|
||||
ggml_sycl_op_flash_attn_2< 96, 96>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
GGML_ASSERT(V->ne[0] == 112);
|
||||
ggml_sycl_op_flash_attn_2<112, 112>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
GGML_ASSERT(V->ne[0] == 128);
|
||||
ggml_sycl_op_flash_attn_2<128, 128>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
#ifndef GGML_SYCL_FATTN_HPP
|
||||
#define GGML_SYCL_FATTN_HPP
|
||||
|
||||
#include "../common.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
// Flash attention operation for SYCL backend
|
||||
// This implements the Flash Attention algorithm optimized for SYCL devices
|
||||
|
|
@ -8,3 +9,5 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|||
|
||||
// Check if flash attention is supported for given tensor
|
||||
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_FATTN_HPP
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
#ifndef GGML_SYCL_FATTN_COMMON_HPP
|
||||
#define GGML_SYCL_FATTN_COMMON_HPP
|
||||
|
||||
template<int N>
|
||||
inline void ggml_sycl_memcpy(float* dst, const float* src) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
dst[i] = src[i];
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_FATTN_COMMON_HPP
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
#ifndef GGML_SYCL_FATTN_KERNEL_HPP
|
||||
#define GGML_SYCL_FATTN_KERNEL_HPP
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
template <int64_t QKD>
|
||||
inline void flash_attn_mul_mat_QK_kernel(
|
||||
sycl::nd_item<2> it,
|
||||
const float * Q, ptrdiff_t q_row_stride,
|
||||
const float * K, ptrdiff_t k_row_stride,
|
||||
float * S, ptrdiff_t s_row_stride,
|
||||
const int Br, const int Bc) {
|
||||
|
||||
const int i = it.get_local_id(0);
|
||||
if (i >= Br) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * q_vec = Q + i * q_row_stride;
|
||||
float * s_row = S + i * s_row_stride;
|
||||
|
||||
for (int j = 0; j < Bc; ++j) {
|
||||
const float * k_vec = K + j * k_row_stride;
|
||||
float score = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QKD; ++k) {
|
||||
score += q_vec[k] * k_vec[k];
|
||||
}
|
||||
|
||||
s_row[j] = score;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void flash_attn_softmax_kernel(
|
||||
sycl::nd_item<2> it,
|
||||
float * S, float * P,
|
||||
float * m_local, float * l_local,
|
||||
const int Br, const int Bc,
|
||||
float * l_d, float * m_d
|
||||
) {
|
||||
const int li = it.get_local_id(0);
|
||||
const int gi = it.get_group(0);
|
||||
const int row = gi * Br + li;
|
||||
|
||||
if (li >= Br) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_offset = li * Bc;
|
||||
|
||||
float m_old = m_d[row];
|
||||
float l_old = l_d[row];
|
||||
|
||||
// 2. Block max
|
||||
float m_block = -INFINITY;
|
||||
for (int j = 0; j < Bc; ++j) {
|
||||
const float s_ij = S[row_offset + j];
|
||||
m_block = sycl::fmax(m_block, s_ij);
|
||||
}
|
||||
|
||||
// 3. Block exp-sum
|
||||
float l_block = 0.0f;
|
||||
for (int j = 0; j < Bc; ++j) {
|
||||
const float e = sycl::exp(S[row_offset + j] - m_block);
|
||||
P[row_offset + j] = e; // temporary store
|
||||
l_block += e;
|
||||
}
|
||||
|
||||
// 4. Merge block stats with global (streaming softmax)
|
||||
float m_new;
|
||||
float l_new;
|
||||
|
||||
if (l_old == 0.0f && m_old == -INFINITY) {
|
||||
// first block for this row
|
||||
m_new = m_block;
|
||||
l_new = l_block;
|
||||
} else {
|
||||
m_new = sycl::fmax(m_old, m_block);
|
||||
|
||||
const float alpha = sycl::exp(m_old - m_new);
|
||||
const float beta = sycl::exp(m_block - m_new);
|
||||
|
||||
l_new = alpha * l_old + beta * l_block;
|
||||
}
|
||||
|
||||
// 5. Store updated global stats
|
||||
m_d[row] = m_new;
|
||||
l_d[row] = l_new;
|
||||
|
||||
// 6. Convert local e_ij to global probabilities p_ij
|
||||
float scale_block = 0.0f;
|
||||
if (l_new > 0.0f) {
|
||||
scale_block = sycl::exp(m_block - m_new) / l_new;
|
||||
}
|
||||
|
||||
for (int j = 0; j < Bc; ++j) {
|
||||
P[row_offset + j] *= scale_block;
|
||||
}
|
||||
|
||||
// 7. Optional: keep local copies
|
||||
m_local[li] = m_new;
|
||||
l_local[li] = l_new;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <int64_t VD>
|
||||
inline void flash_attn_mul_mat_PV_kernel(
|
||||
sycl::nd_item<2> it,
|
||||
const float * P, ptrdiff_t p_row_stride,
|
||||
const float * V, ptrdiff_t v_row_stride,
|
||||
float * O, ptrdiff_t o_row_stride,
|
||||
const int Br, const int Bc) {
|
||||
|
||||
const int i = it.get_local_id(0);
|
||||
if (i >= Br) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * p_row = P + i * p_row_stride;
|
||||
float * o_row = O + i * o_row_stride;
|
||||
|
||||
for (int j = 0; j < VD; ++j) {
|
||||
float acc = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < Bc; ++k) {
|
||||
const float * v_row = V + k * v_row_stride;
|
||||
acc += p_row[k] * v_row[j];
|
||||
}
|
||||
|
||||
o_row[j] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_FATTN_KERNEL_HPP
|
||||
|
||||
|
||||
|
|
@ -1,98 +0,0 @@
|
|||
#include "flash-attn-sycl.h"
|
||||
|
||||
#include "kernels/flash-attn-kernel.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#define FLASH_ATTN_BR_MAX 32
|
||||
#define FLASH_ATTN_BC_MAX 32
|
||||
|
||||
// Flash Attention: https://arxiv.org/abs/2205.14135
|
||||
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
GGML_ASSERT(Q != nullptr);
|
||||
GGML_ASSERT(K != nullptr);
|
||||
GGML_ASSERT(V != nullptr);
|
||||
GGML_ASSERT(dst != nullptr);
|
||||
|
||||
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
|
||||
return;
|
||||
}
|
||||
|
||||
const float * Q_d = (const float *) Q->data;
|
||||
const float * K_d = (const float *) K->data;
|
||||
const float * V_d = (const float *) V->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
const int64_t d = Q->ne[0];
|
||||
const int64_t N = Q->ne[1];
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
const bool masked = (mask != nullptr);
|
||||
|
||||
const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
|
||||
const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
|
||||
|
||||
const int Tr = (N + Br - 1) / Br;
|
||||
const int Tc = (N + Bc - 1) / Bc;
|
||||
|
||||
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||
|
||||
stream->fill(l_d, 0.0f, N);
|
||||
stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
|
||||
stream->fill(dst_d, 0.0f, N * d);
|
||||
stream->wait();
|
||||
|
||||
for (int j = 0; j < Tc; ++j) {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
|
||||
const int i = idx[0];
|
||||
flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
|
||||
Bc, N, d, masked, scale);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
stream->wait();
|
||||
|
||||
sycl::free(l_d, *stream);
|
||||
sycl::free(m_d, *stream);
|
||||
}
|
||||
|
||||
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
if (Q == nullptr || K == nullptr || V == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
template <int Br_MAX = 32, int Bc_MAX = 32>
|
||||
inline void flash_attn_tiled_kernel(const float * Q,
|
||||
const float * K,
|
||||
const float * V,
|
||||
float * O,
|
||||
float * l,
|
||||
float * m,
|
||||
const int i_block,
|
||||
const int j_block,
|
||||
const int Br,
|
||||
const int Bc,
|
||||
const int N,
|
||||
const int d,
|
||||
const bool masked,
|
||||
const float scale) {
|
||||
const int i_start = i_block * Br;
|
||||
const int j_start = j_block * Bc;
|
||||
|
||||
float S[Br_MAX][Bc_MAX];
|
||||
float P[Br_MAX][Bc_MAX];
|
||||
float m_local[Br_MAX];
|
||||
float l_local[Br_MAX];
|
||||
|
||||
for (int qi = 0; qi < Br; ++qi) {
|
||||
const int q_row = i_start + qi;
|
||||
if (q_row >= N) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int kj = 0; kj < Bc; ++kj) {
|
||||
const int k_row = j_start + kj;
|
||||
if (k_row >= N) {
|
||||
S[qi][kj] = -INFINITY;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (masked && k_row > q_row) {
|
||||
S[qi][kj] = -INFINITY;
|
||||
continue;
|
||||
}
|
||||
|
||||
float score = 0.0f;
|
||||
for (int k = 0; k < d; ++k) {
|
||||
score += Q[q_row * d + k] * K[k_row * d + k];
|
||||
}
|
||||
S[qi][kj] = score * scale;
|
||||
}
|
||||
}
|
||||
|
||||
for (int qi = 0; qi < Br; ++qi) {
|
||||
const int q_row = i_start + qi;
|
||||
if (q_row >= N) {
|
||||
continue;
|
||||
}
|
||||
|
||||
m_local[qi] = -INFINITY;
|
||||
for (int kj = 0; kj < Bc; ++kj) {
|
||||
if (j_start + kj < N) {
|
||||
m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
|
||||
}
|
||||
}
|
||||
|
||||
l_local[qi] = 0.0f;
|
||||
for (int kj = 0; kj < Bc; ++kj) {
|
||||
if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
|
||||
P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
|
||||
l_local[qi] += P[qi][kj];
|
||||
} else {
|
||||
P[qi][kj] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int qi = 0; qi < Br; ++qi) {
|
||||
const int q_row = i_start + qi;
|
||||
if (q_row >= N) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float m_old = m[q_row];
|
||||
const float m_new = sycl::fmax(m_old, m_local[qi]);
|
||||
const float l_old = l[q_row];
|
||||
const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];
|
||||
|
||||
const float correction_old = sycl::exp(m_old - m_new);
|
||||
const float correction_new = sycl::exp(m_local[qi] - m_new);
|
||||
|
||||
for (int k = 0; k < d; ++k) {
|
||||
float pv = 0.0f;
|
||||
for (int kj = 0; kj < Bc; ++kj) {
|
||||
const int v_row = j_start + kj;
|
||||
if (v_row < N) {
|
||||
pv += P[qi][kj] * V[v_row * d + k];
|
||||
}
|
||||
}
|
||||
|
||||
const int o_idx = q_row * d + k;
|
||||
O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new;
|
||||
}
|
||||
|
||||
l[q_row] = l_new;
|
||||
m[q_row] = m_new;
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue