add op gated_delta_net (#20455)

This commit is contained in:
Neo Zhang 2026-03-14 22:01:57 +08:00 committed by GitHub
parent 710878a7dd
commit a93c0ef0fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 354 additions and 28 deletions

View File

@ -15,7 +15,7 @@ Legend:
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -47,7 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |

View File

@ -6841,10 +6841,6 @@
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=193,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=67,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=64,n=77,k=77,bs=[12,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=3,k=4,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,1,2,3],k_v=64,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
@ -10261,8 +10257,8 @@
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","0","no","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","0","no","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[64,16,2,3],stride_dim=3","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","SYCL"
@ -13591,16 +13587,21 @@
"SYCL0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","SYCL"
"SYCL0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
"SYCL0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"

Can't render this file because it is too large.

View File

@ -211,7 +211,7 @@ struct sycl_device_info {
// number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
int warp_size; // max sub_group_size of SYCL
int warp_size; // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPU, 16 is better in most cases. Some OP support 32 only.
int max_wg_per_cu; // max work groups per compute unit - refer to
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
bool vmm; // virtual memory support

View File

@ -0,0 +1,309 @@
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "ggml.h"
#include "gated_delta_net.hpp"
#include <cmath>
template <int S_v, bool KDA>
void gated_delta_net_sycl(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
const sycl::uint3 neqk1_magic,
const sycl::uint3 rq3_magic,
float scale) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const uint32_t h_idx = item_ct1.get_group(2);
const uint32_t sequence = item_ct1.get_group(1);
// each warp owns one column, using warp-level primitives to reduce across rows
const int lane = item_ct1.get_local_id(2);
const int col = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
float * state = dst + attn_score_elems;
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
}
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
const float * beta_t = beta + gb_offset;
const float * g_t = g + gb_offset * (KDA ? S_v : 1);
const float beta_val = *beta_t;
if constexpr (!KDA) {
const float g_val = sycl::native::exp(*g_t);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
}
attn_data += S_v * H;
}
// Write state back to global memory
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
}
}
template <bool KDA>
static void launch_gated_delta_net(const float * q_d,
const float * k_d,
const float * v_d,
const float * g_d,
const float * b_d,
const float * s_d,
float * dst_d,
int64_t S_v,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t neqk1,
int64_t rq3,
float scale,
dpct::queue_ptr stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
const int num_warps = 4;
dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
const sycl::uint3 rq3_magic = init_fastdiv_values(rq3);
int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
switch (S_v) {
case 16:
{
constexpr int sv = 16;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
case 32:
{
constexpr int sv = 32;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
case 64: {
{
constexpr int sv = 64;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
}
case 128: {
{
constexpr int sv = 128;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
}
default:
GGML_ABORT("fatal error");
break;
}
}
void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
const int64_t S_v = nev0;
const int64_t H = nev1;
const int64_t n_tokens = nev2;
const int64_t n_seqs = nev3;
const bool kda = (src_g->ne[0] == S_v);
GGML_ASSERT(neq1 == nek1);
const int64_t neqk1 = neq1;
const int64_t rq3 = nev3 / neq3;
const float * q_d = (const float *) src_q->data;
const float * k_d = (const float *) src_k->data;
const float * v_d = (const float *) src_v->data;
const float * g_d = (const float *) src_g->data;
const float * b_d = (const float *) src_beta->data;
const float * s_d = (const float *) src_state->data;
float * dst_d = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
GGML_ASSERT(src_g->ne[0] == 1 || kda);
GGML_ASSERT(ggml_is_contiguous(src_g));
GGML_ASSERT(ggml_is_contiguous(src_beta));
GGML_ASSERT(ggml_is_contiguous(src_state));
// strides in floats (beta strides used for both g and beta offset computation)
const int64_t sq1 = nbq1 / sizeof(float);
const int64_t sq2 = nbq2 / sizeof(float);
const int64_t sq3 = nbq3 / sizeof(float);
const int64_t sv1 = nbv1 / sizeof(float);
const int64_t sv2 = nbv2 / sizeof(float);
const int64_t sv3 = nbv3 / sizeof(float);
const int64_t sb1 = nbb1 / sizeof(float);
const int64_t sb2 = nbb2 / sizeof(float);
const int64_t sb3 = nbb3 / sizeof(float);
const float scale = 1.0f / sqrtf((float) S_v);
dpct::queue_ptr stream = ctx.stream();
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);
ggml_sycl_op_gated_delta_net(ctx, dst);
}

View File

@ -0,0 +1,8 @@
#pragma once
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "ggml.h"
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View File

@ -35,6 +35,7 @@
#endif
#include <sycl/half_type.hpp>
#include "ggml.h"
#include "ggml-sycl.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
@ -43,17 +44,18 @@
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/element_wise.hpp"
#include "ggml-sycl/gated_delta_net.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/norm.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml.h"
#include "ggml-sycl/sycl_hw.hpp"
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
@ -99,6 +101,8 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].smpbo = prop.get_local_mem_size();
info.devices[i].warp_size = WARP_SIZE;
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
@ -4181,6 +4185,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_GATED_DELTA_NET:
ggml_sycl_gated_delta_net(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
break;
@ -4890,6 +4897,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_GATED_DELTA_NET:
return true;
case GGML_OP_SSM_CONV:
return op->type == GGML_TYPE_F32 &&