Merge 15c48eb069 into 58062860af
This commit is contained in:
commit
529353ce83
|
|
@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
|
||||||
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
||||||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
||||||
|
|
||||||
|
# Include flash-attn sources (SYCL optimized flash attention implementation)
|
||||||
|
file(GLOB GGML_HEADERS_SYCL_FLASH "fattn*.h" "fattn*.hpp")
|
||||||
|
file(GLOB GGML_SOURCES_SYCL_FLASH "fattn*.cpp" "fattn*.c")
|
||||||
|
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
|
||||||
|
|
||||||
|
# Also include kernel headers under flash-attn/kernels
|
||||||
|
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "fattn_kernel*.h" "fattn_kernel*.hpp")
|
||||||
|
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
|
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
|
||||||
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
|
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@
|
||||||
#include "tsembd.hpp"
|
#include "tsembd.hpp"
|
||||||
#include "wkv.hpp"
|
#include "wkv.hpp"
|
||||||
#include "pad_reflect_1d.hpp"
|
#include "pad_reflect_1d.hpp"
|
||||||
|
#include "fattn.hpp"
|
||||||
|
|
||||||
|
|
||||||
#endif // GGML_SYCL_BACKEND_HPP
|
#endif // GGML_SYCL_BACKEND_HPP
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,236 @@
|
||||||
|
#include "./fattn.hpp"
|
||||||
|
#include "./fattn_kernel.hpp"
|
||||||
|
#include "./fattn_common.hpp"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <limits>
|
||||||
|
#include <sycl/sycl.hpp>
|
||||||
|
|
||||||
|
#define Br 32
|
||||||
|
#define Bc 32
|
||||||
|
|
||||||
|
|
||||||
|
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
|
float scale, max_bias, logit_softcap;
|
||||||
|
|
||||||
|
std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||||
|
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||||
|
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
|
if( max_bias != 0.0f || logit_softcap != 0.0f){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Q == nullptr || K == nullptr || V == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mask != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t DQK = Q->ne[0];
|
||||||
|
int64_t DV = V->ne[0];
|
||||||
|
|
||||||
|
if (DQK != DV){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DV != 32 && DV != 64 && DV != 80 && DV != 96 && DV != 112 && DV != 128 && DV != 256 && DV != 512){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//not support multi-head yet
|
||||||
|
if (Q->ne[2] != 1 || K->ne[2] != 1 || V->ne[2] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int64_t DQK, int64_t DV>
|
||||||
|
void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
const float * Q_d = (const float *) Q->data;
|
||||||
|
const float * K_d = (const float *) K->data;
|
||||||
|
const float * V_d = (const float *) V->data;
|
||||||
|
float * dst_d = (float *) dst->data;
|
||||||
|
|
||||||
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
const int64_t N = Q->ne[1];
|
||||||
|
|
||||||
|
const ptrdiff_t q_row_stride = Q->nb[1] / (ptrdiff_t)sizeof(float);
|
||||||
|
const ptrdiff_t k_row_stride = K->nb[1] / (ptrdiff_t)sizeof(float);
|
||||||
|
const ptrdiff_t v_row_stride = V->nb[1] / (ptrdiff_t)sizeof(float);
|
||||||
|
const ptrdiff_t o_row_stride = dst->nb[1] / (ptrdiff_t)sizeof(float);
|
||||||
|
|
||||||
|
// const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
|
||||||
|
// const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
|
||||||
|
|
||||||
|
const int Tr = (N + Br - 1) / Br;
|
||||||
|
const int Tc = (N + Bc - 1) / Bc;
|
||||||
|
|
||||||
|
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||||
|
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
|
||||||
|
|
||||||
|
sycl::range<2> global(Br * Tr, Tc);
|
||||||
|
sycl::range<2> local(Br,1);
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
sycl::local_accessor<float, 2> Qtile({Br, DQK}, cgh);
|
||||||
|
sycl::local_accessor<float, 2> Ktile({Bc, DQK}, cgh);
|
||||||
|
sycl::local_accessor<float, 2> Vtile({Bc, DV}, cgh);
|
||||||
|
sycl::local_accessor<float, 2> Stile({Br, Bc}, cgh);
|
||||||
|
sycl::local_accessor<float, 1> Ptile({Br * Bc}, cgh);
|
||||||
|
sycl::local_accessor<float, 1> m_local({Br}, cgh);
|
||||||
|
sycl::local_accessor<float, 1> l_local({Br}, cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
|
||||||
|
|
||||||
|
float* q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
float* l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
|
|
||||||
|
auto group = it.get_group();
|
||||||
|
int group_id_i = group.get_group_id(0);
|
||||||
|
int group_id_j = group.get_group_id(1);
|
||||||
|
|
||||||
|
|
||||||
|
int row0 = group_id_i * Br;
|
||||||
|
int col0 = group_id_j * Bc;
|
||||||
|
|
||||||
|
if (row0 >= (int) N || col0 >= (int) N) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float* Q_block = Q_d + (ptrdiff_t)row0 * q_row_stride;
|
||||||
|
const float* K_block = K_d + (ptrdiff_t)col0 * k_row_stride;
|
||||||
|
const float* V_block = V_d + (ptrdiff_t)col0 * v_row_stride;
|
||||||
|
float* O_block = dst_d + (ptrdiff_t)row0 * o_row_stride;
|
||||||
|
|
||||||
|
//this lines does not support non-contiguous tensors
|
||||||
|
ggml_sycl_memcpy<Br * DQK>(q_loc, Q_block);
|
||||||
|
ggml_sycl_memcpy<Bc * DQK>(k_loc, K_block);
|
||||||
|
ggml_sycl_memcpy<Bc * DV>(v_loc, V_block);
|
||||||
|
|
||||||
|
it.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
flash_attn_mul_mat_QK_kernel<DQK>(
|
||||||
|
it,
|
||||||
|
Q_block, q_row_stride,
|
||||||
|
K_block, k_row_stride,
|
||||||
|
s_loc, (ptrdiff_t)Bc,
|
||||||
|
Br, Bc
|
||||||
|
);
|
||||||
|
|
||||||
|
it.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
flash_attn_softmax_kernel(
|
||||||
|
it,
|
||||||
|
s_loc, p_loc,
|
||||||
|
m_loc, l_loc,
|
||||||
|
Br, Bc,
|
||||||
|
l_d, m_d
|
||||||
|
);
|
||||||
|
|
||||||
|
it.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
flash_attn_mul_mat_PV_kernel<DV>(
|
||||||
|
it,
|
||||||
|
p_loc, (ptrdiff_t)Bc,
|
||||||
|
V_block, v_row_stride,
|
||||||
|
O_block, o_row_stride,
|
||||||
|
Br,Bc
|
||||||
|
);
|
||||||
|
|
||||||
|
it.barrier(sycl::access::fence_space::local_space);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
const ptrdiff_t o_stride = o_row_stride;
|
||||||
|
|
||||||
|
cgh.parallel_for(sycl::range<1>(N), [=](sycl::id<1> id_row) {
|
||||||
|
int row = id_row[0];
|
||||||
|
float l_val = l_d[row];
|
||||||
|
|
||||||
|
if (l_val <= 0.0f) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float inv_l = 1.0f / l_val;
|
||||||
|
float * o_row = dst_d + (ptrdiff_t)row * o_stride;
|
||||||
|
|
||||||
|
for (int col = 0; col < DV; ++col) {
|
||||||
|
o_row[col] *= inv_l;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
sycl::free(l_d, *stream);
|
||||||
|
sycl::free(m_d, *stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 32:
|
||||||
|
GGML_ASSERT(V->ne[0] == 32);
|
||||||
|
ggml_sycl_op_flash_attn_2< 32, 32>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
GGML_ASSERT(V->ne[0] == 64);
|
||||||
|
ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
GGML_ASSERT(V->ne[0] == 80);
|
||||||
|
ggml_sycl_op_flash_attn_2< 80, 80>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
GGML_ASSERT(V->ne[0] == 96);
|
||||||
|
ggml_sycl_op_flash_attn_2< 96, 96>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 112:
|
||||||
|
GGML_ASSERT(V->ne[0] == 112);
|
||||||
|
ggml_sycl_op_flash_attn_2<112, 112>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
GGML_ASSERT(V->ne[0] == 128);
|
||||||
|
ggml_sycl_op_flash_attn_2<128, 128>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
GGML_ASSERT(V->ne[0] == 256);
|
||||||
|
ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 576:
|
||||||
|
GGML_ASSERT(V->ne[0] == 512);
|
||||||
|
ggml_sycl_op_flash_attn_2<512, 512>(ctx, dst);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
fprintf(stderr, "Warning: Unsupported head size %ld — skipping op\n", Q->ne[0]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
#ifndef GGML_SYCL_FATTN_HPP
|
||||||
|
#define GGML_SYCL_FATTN_HPP
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
|
||||||
|
// Flash attention operation for SYCL backend
|
||||||
|
// This implements the Flash Attention algorithm optimized for SYCL devices
|
||||||
|
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
// Check if flash attention is supported for given tensor
|
||||||
|
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_FATTN_HPP
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
#ifndef GGML_SYCL_FATTN_COMMON_HPP
|
||||||
|
#define GGML_SYCL_FATTN_COMMON_HPP
|
||||||
|
|
||||||
|
template<int N>
|
||||||
|
inline void ggml_sycl_memcpy(float* dst, const float* src) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_FATTN_COMMON_HPP
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
#ifndef GGML_SYCL_FATTN_KERNEL_HPP
|
||||||
|
#define GGML_SYCL_FATTN_KERNEL_HPP
|
||||||
|
|
||||||
|
#include <sycl/sycl.hpp>
|
||||||
|
|
||||||
|
template <int64_t QKD>
|
||||||
|
inline void flash_attn_mul_mat_QK_kernel(
|
||||||
|
sycl::nd_item<2> it,
|
||||||
|
const float * Q, ptrdiff_t q_row_stride,
|
||||||
|
const float * K, ptrdiff_t k_row_stride,
|
||||||
|
float * S, ptrdiff_t s_row_stride,
|
||||||
|
const int Br, const int Bc) {
|
||||||
|
|
||||||
|
const int i = it.get_local_id(0);
|
||||||
|
if (i >= Br) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * q_vec = Q + i * q_row_stride;
|
||||||
|
float * s_row = S + i * s_row_stride;
|
||||||
|
|
||||||
|
for (int j = 0; j < Bc; ++j) {
|
||||||
|
const float * k_vec = K + j * k_row_stride;
|
||||||
|
float score = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < QKD; ++k) {
|
||||||
|
score += q_vec[k] * k_vec[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
s_row[j] = score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline void flash_attn_softmax_kernel(
|
||||||
|
sycl::nd_item<2> it,
|
||||||
|
float * S, float * P,
|
||||||
|
float * m_local, float * l_local,
|
||||||
|
const int Br, const int Bc,
|
||||||
|
float * l_d, float * m_d
|
||||||
|
) {
|
||||||
|
const int li = it.get_local_id(0);
|
||||||
|
const int gi = it.get_group(0);
|
||||||
|
const int row = gi * Br + li;
|
||||||
|
|
||||||
|
if (li >= Br) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int row_offset = li * Bc;
|
||||||
|
|
||||||
|
float m_old = m_d[row];
|
||||||
|
float l_old = l_d[row];
|
||||||
|
|
||||||
|
// Block max
|
||||||
|
float m_block = -INFINITY;
|
||||||
|
for (int j = 0; j < Bc; ++j) {
|
||||||
|
const float s_ij = S[row_offset + j];
|
||||||
|
m_block = sycl::fmax(m_block, s_ij);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Block exp-sum
|
||||||
|
float l_block = 0.0f;
|
||||||
|
for (int j = 0; j < Bc; ++j) {
|
||||||
|
const float e = sycl::exp(S[row_offset + j] - m_block);
|
||||||
|
P[row_offset + j] = e; // temporary store
|
||||||
|
l_block += e;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge block stats with global (streaming softmax)
|
||||||
|
float m_new;
|
||||||
|
float l_new;
|
||||||
|
|
||||||
|
if (l_old == 0.0f && m_old == -INFINITY) {
|
||||||
|
// first block for this row
|
||||||
|
m_new = m_block;
|
||||||
|
l_new = l_block;
|
||||||
|
} else {
|
||||||
|
m_new = sycl::fmax(m_old, m_block);
|
||||||
|
|
||||||
|
const float alpha = sycl::exp(m_old - m_new);
|
||||||
|
const float beta = sycl::exp(m_block - m_new);
|
||||||
|
|
||||||
|
l_new = alpha * l_old + beta * l_block;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store updated global stats
|
||||||
|
m_d[row] = m_new;
|
||||||
|
l_d[row] = l_new;
|
||||||
|
|
||||||
|
// Convert local e_ij to global probabilities p_ij
|
||||||
|
float scale_block = 0.0f;
|
||||||
|
if (l_new > 0.0f) {
|
||||||
|
scale_block = sycl::exp(m_block - m_new) / l_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j = 0; j < Bc; ++j) {
|
||||||
|
P[row_offset + j] *= scale_block;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: keep local copies
|
||||||
|
m_local[li] = m_new;
|
||||||
|
l_local[li] = l_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <int64_t VD>
|
||||||
|
inline void flash_attn_mul_mat_PV_kernel(
|
||||||
|
sycl::nd_item<2> it,
|
||||||
|
const float * P, ptrdiff_t p_row_stride,
|
||||||
|
const float * V, ptrdiff_t v_row_stride,
|
||||||
|
float * O, ptrdiff_t o_row_stride,
|
||||||
|
const int Br, const int Bc) {
|
||||||
|
|
||||||
|
const int i = it.get_local_id(0);
|
||||||
|
if (i >= Br) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * p_row = P + i * p_row_stride;
|
||||||
|
float * o_row = O + i * o_row_stride;
|
||||||
|
|
||||||
|
for (int j = 0; j < VD; ++j) {
|
||||||
|
float acc = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < Bc; ++k) {
|
||||||
|
const float * v_row = V + k * v_row_stride;
|
||||||
|
acc += p_row[k] * v_row[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
o_row[j] = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_FATTN_KERNEL_HPP
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3927,6 +3927,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
ggml_sycl_argsort(ctx, dst);
|
ggml_sycl_argsort(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
ggml_sycl_op_flash_attn(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
ggml_sycl_op_timestep_embedding(ctx, dst);
|
ggml_sycl_op_timestep_embedding(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
|
@ -4617,6 +4620,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
return ggml_sycl_flash_attn_ext_supported(op);
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
return op->src[0]->ne[0] * sizeof(int) <=
|
return op->src[0]->ne[0] * sizeof(int) <=
|
||||||
ggml_sycl_info().devices[device].smpbo;
|
ggml_sycl_info().devices[device].smpbo;
|
||||||
|
|
|
||||||
|
|
@ -8162,6 +8162,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int kv : { 4096, 8192, 16384, }) {
|
||||||
|
for (int hs : { 64, 128, }) {
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 1, {1, 1}, kv, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F32));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
||||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue