diff --git a/docs/ops.md b/docs/ops.md index f914c2b7d2..00bdcd5d2b 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -15,7 +15,7 @@ Legend: | Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN | |-----------|------|------|------|------|------|------|------|------|------|------|------| | ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | -| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | +| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | @@ -47,7 +47,7 @@ Legend: | FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | -| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index 03bfacfc9e..13ac447cc1 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -6841,10 +6841,6 @@ "SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=193,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=67,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=64,n=77,k=77,bs=[12,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" -"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" -"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=3,k=4,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" -"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL" -"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,1,2,3],k_v=64,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" @@ -10261,8 +10257,8 @@ "SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1],stride_dim=-1","support","1","yes","SYCL" "SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=-1","support","1","yes","SYCL" "SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=-1","support","1","yes","SYCL" -"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","1","yes","SYCL" -"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","1","yes","SYCL" +"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","0","no","SYCL" +"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","0","no","SYCL" "SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[64,16,2,3],stride_dim=3","support","1","yes","SYCL" "SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","SYCL" "SYCL0","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","SYCL" @@ -13591,16 +13587,21 @@ "SYCL0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","SYCL" "SYCL0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","SYCL" "SYCL0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","SYCL" -"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL" +"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL" diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 9f0efb6535..fcb0db99c6 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -211,7 +211,7 @@ struct sycl_device_info { // number of compute units on a SYCL device. // size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) - int warp_size; // max sub_group_size of SYCL + int warp_size; // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPU, 16 is better in most cases. Some OP support 32 only. int max_wg_per_cu; // max work groups per compute unit - refer to // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp new file mode 100644 index 0000000000..8c76afbd57 --- /dev/null +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -0,0 +1,309 @@ +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" +#include "gated_delta_net.hpp" +#include + + +template +void gated_delta_net_sycl(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const sycl::uint3 neqk1_magic, + const sycl::uint3 rq3_magic, + float scale) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const uint32_t h_idx = item_ct1.get_group(2); + const uint32_t sequence = item_ct1.get_group(1); + // each warp owns one column, using warp-level primitives to reduce across rows + const int lane = item_ct1.get_local_id(2); + const int col = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic); + const uint32_t iq3 = fastdiv(sequence, rq3_magic); + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; + state += state_offset; + curr_state += state_offset; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; + float s_shard[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i * S_v + col]; + } + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + if constexpr (!KDA) { + const float g_val = sycl::native::exp(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += s_shard[r] * k_t[i]; + } + float kv_col = warp_reduce_sum(kv_shard); + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; + } + + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i]; + } + + float kv_col = warp_reduce_sum(kv_shard); + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; + } + + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } + + attn_data += S_v * H; + } + + // Write state back to global memory +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[i * S_v + col] = s_shard[r]; + } +} + +template +static void launch_gated_delta_net(const float * q_d, + const float * k_d, + const float * v_d, + const float * g_d, + const float * b_d, + const float * s_d, + float * dst_d, + int64_t S_v, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + int64_t neqk1, + int64_t rq3, + float scale, + dpct::queue_ptr stream) { + //TODO: Add chunked kernel for even faster pre-fill + const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; + + const int num_warps = 4; + dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); + + const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1); + const sycl::uint3 rq3_magic = init_fastdiv_values(rq3); + + int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; + + switch (S_v) { + case 16: + { + constexpr int sv = 16; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, + sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + case 32: + { + constexpr int sv = 32; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, + sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + case 64: { + { + constexpr int sv = 64; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + } + case 128: { + { + constexpr int sv = 128; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + } + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + dpct::queue_ptr stream = ctx.stream(); + + if (kda) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, stream); + } +} + +void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); + ggml_sycl_op_gated_delta_net(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/gated_delta_net.hpp b/ggml/src/ggml-sycl/gated_delta_net.hpp new file mode 100644 index 0000000000..a3308ee876 --- /dev/null +++ b/ggml/src/ggml-sycl/gated_delta_net.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" + +void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f887061b27..1281970584 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -35,6 +35,7 @@ #endif #include +#include "ggml.h" #include "ggml-sycl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -43,17 +44,18 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" +#include "ggml-sycl/gated_delta_net.hpp" +#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/getrows.hpp" #include "ggml-sycl/norm.hpp" #include "ggml-sycl/presets.hpp" -#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/quantize.hpp" +#include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/set_rows.hpp" #include "ggml-sycl/set.hpp" -#include "ggml-sycl/sycl_hw.hpp" -#include "ggml-sycl/getrows.hpp" -#include "ggml-sycl/repeat_back.hpp" -#include "ggml-sycl/quantize.hpp" #include "ggml-sycl/ssm_conv.hpp" -#include "ggml.h" +#include "ggml-sycl/sycl_hw.hpp" + static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; @@ -99,6 +101,8 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.devices[i].smpbo = prop.get_local_mem_size(); + info.devices[i].warp_size = WARP_SIZE; + info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); @@ -4181,6 +4185,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_sycl_gated_delta_net(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_sycl_ssm_conv(ctx, dst); break; @@ -4890,6 +4897,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_GATED_DELTA_NET: return true; case GGML_OP_SSM_CONV: return op->type == GGML_TYPE_F32 &&