llama : enable chunked fused GDN path (#20340)
* llama : enable chunked fused GDN path
* models : avoid Q and K repeats when using fused GDA
* cont : fix comment
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
* cont : fix the fix
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
* cont : fix
* metal : add GDN kernel (#20361)
* metal : add Metal backend for GGML_OP_GATED_DELTA_NET
Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.
Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.
Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
tg128: 170 -> 213 t/s (+25%)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* metal : validate contiguity of all input tensors in supports_op
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* metal : add algorithm equivalence comment for GDA decay path
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* cont : unslop + optimize
* cont : clean-up
---------
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
* CUDA: AR gated delta net improvements (#20391)
* Add FastDiv to gated_delta_net_cuda
* Shard columns across warps
This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).
* Remove unneded include in gated_delta_net.cu
* Improve comments
* Apply code-formating
* Make sharding HIP-compatible
1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA
* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t
* Rename variables
* Enable GDN also for prefill, move TODO for chunked_GDN
* Actually remove the TODO from 2068908975
* Get warp size at runtime
warp_size is not known at compile time in hip host code.
* Don't expose ggml_cuda_get_physical_warp_size on host
---------
Co-authored-by: uvos <devnull@uvos.xyz>
* llama : refactor llm_build_delta_net_base API
---------
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>
This commit is contained in:
parent
f90bd1dd84
commit
d28961d81e
|
|
@ -2466,6 +2466,8 @@ extern "C" {
|
|||
bool lower,
|
||||
bool uni);
|
||||
|
||||
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
|
||||
GGML_API struct ggml_tensor * ggml_gated_delta_net(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
|
|
|
|||
|
|
@ -10443,8 +10443,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
|
||||
const float * state_in_base = (const float *)src_state->data;
|
||||
|
||||
const int64_t rq1 = nev1 / neq1;
|
||||
const int64_t rk1 = nev1 / nek1;
|
||||
//const int64_t rq1 = nev1 / neq1;
|
||||
//const int64_t rk1 = nev1 / nek1;
|
||||
const int64_t rq3 = nev3 / neq3;
|
||||
const int64_t rk3 = nev3 / nek3;
|
||||
|
||||
|
|
@ -10454,8 +10454,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
const int64_t iv1 = ir % H; // head_index
|
||||
const int64_t iv3 = ir / H; // sequence
|
||||
|
||||
const int64_t iq1 = iv1 / rq1;
|
||||
const int64_t ik1 = iv1 / rk1;
|
||||
const int64_t iq1 = iv1 % neq1;
|
||||
const int64_t ik1 = iv1 % nek1;
|
||||
|
||||
const int64_t iq3 = iv3 / rq3;
|
||||
const int64_t ik3 = iv3 / rk3;
|
||||
|
|
@ -10475,7 +10475,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
|
||||
|
||||
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
|
||||
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
|
||||
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
|
||||
|
||||
if (kda) {
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
|
|
@ -10508,7 +10508,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
|
||||
attn_data += S_v * H; // advance to next token
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,36 +1,36 @@
|
|||
#include "gated_delta_net.cuh"
|
||||
#include "ggml-cuda/common.cuh"
|
||||
|
||||
template <int S_v, bool KDA>
|
||||
__global__ void __launch_bounds__(S_v, 1)
|
||||
gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
const int64_t H,
|
||||
const int64_t n_tokens,
|
||||
const int64_t n_seqs,
|
||||
const int64_t sq1,
|
||||
const int64_t sq2,
|
||||
const int64_t sq3,
|
||||
const int64_t sv1,
|
||||
const int64_t sv2,
|
||||
const int64_t sv3,
|
||||
const int64_t sb1,
|
||||
const int64_t sb2,
|
||||
const int64_t sb3,
|
||||
const int64_t rq1,
|
||||
const int64_t rq3,
|
||||
const float scale) {
|
||||
const int64_t h_idx = blockIdx.x;
|
||||
const int64_t sequence = blockIdx.y;
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
__global__ void gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs,
|
||||
int64_t sq1,
|
||||
int64_t sq2,
|
||||
int64_t sq3,
|
||||
int64_t sv1,
|
||||
int64_t sv2,
|
||||
int64_t sv3,
|
||||
int64_t sb1,
|
||||
int64_t sb2,
|
||||
int64_t sb3,
|
||||
const uint3 neqk1_magic,
|
||||
const uint3 rq3_magic,
|
||||
float scale) {
|
||||
const uint32_t h_idx = blockIdx.x;
|
||||
const uint32_t sequence = blockIdx.y;
|
||||
// each warp owns one column, using warp-level primitives to reduce across rows
|
||||
const int lane = threadIdx.x;
|
||||
const int col = blockIdx.z * blockDim.y + threadIdx.y;
|
||||
|
||||
const int64_t iq1 = h_idx / rq1;
|
||||
const int64_t iq3 = sequence / rq3;
|
||||
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
|
||||
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
|
||||
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
float * attn_data = dst;
|
||||
|
|
@ -41,17 +41,14 @@ gated_delta_net_cuda(const float * q,
|
|||
curr_state += state_offset;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
// GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
|
||||
// TODO: check optimal path for RDNA1 and RDNA2 devices.
|
||||
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
|
||||
extern __shared__ float s_shared[];
|
||||
float * s = s_shared + col * S_v;
|
||||
#else
|
||||
float s[S_v];
|
||||
#endif
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
|
||||
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
|
||||
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
|
||||
float s_shard[rows_per_lane];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = curr_state[i * S_v + col];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
s_shard[r] = curr_state[i * S_v + col];
|
||||
}
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
|
|
@ -69,46 +66,61 @@ gated_delta_net_cuda(const float * q,
|
|||
const float g_val = expf(*g_t);
|
||||
|
||||
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
|
||||
float kv_col = 0.0f;
|
||||
float kv_shard = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
kv_col += s[i] * k_t[i];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
kv_shard += s_shard[r] * k_t[i];
|
||||
}
|
||||
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
|
||||
|
||||
// delta[col] = (v[col] - g * kv[col]) * beta
|
||||
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
|
||||
|
||||
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
|
||||
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
|
||||
float attn_col = 0.0f;
|
||||
float attn_partial = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = g_val * s[i] + k_t[i] * delta_col;
|
||||
attn_col += s[i] * q_t[i];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
|
||||
attn_partial += s_shard[r] * q_t[i];
|
||||
}
|
||||
|
||||
attn_data[col] = attn_col * scale;
|
||||
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
|
||||
|
||||
if (lane == 0) {
|
||||
attn_data[col] = attn_col * scale;
|
||||
}
|
||||
} else {
|
||||
// kv[col] = sum_i g[i] * S[i][col] * k[i]
|
||||
float kv_col = 0.0f;
|
||||
float kv_shard = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
kv_col += expf(g_t[i]) * s[i] * k_t[i];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
|
||||
}
|
||||
|
||||
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
|
||||
|
||||
// delta[col] = (v[col] - kv[col]) * beta
|
||||
float delta_col = (v_t[col] - kv_col) * beta_val;
|
||||
|
||||
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
|
||||
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
|
||||
float attn_col = 0.0f;
|
||||
float attn_partial = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
|
||||
attn_col += s[i] * q_t[i];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
|
||||
attn_partial += s_shard[r] * q_t[i];
|
||||
}
|
||||
|
||||
attn_data[col] = attn_col * scale;
|
||||
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
|
||||
|
||||
if (lane == 0) {
|
||||
attn_data[col] = attn_col * scale;
|
||||
}
|
||||
}
|
||||
|
||||
attn_data += S_v * H;
|
||||
|
|
@ -116,8 +128,9 @@ gated_delta_net_cuda(const float * q,
|
|||
|
||||
// Write state back to global memory
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
state[i * S_v + col] = s[i];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
state[i * S_v + col] = s_shard[r];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -135,35 +148,43 @@ static void launch_gated_delta_net(
|
|||
const float * q_d, const float * k_d, const float * v_d,
|
||||
const float * g_d, const float * b_d, const float * s_d,
|
||||
float * dst_d,
|
||||
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
|
||||
int64_t sq1, int64_t sq2, int64_t sq3,
|
||||
int64_t sv1, int64_t sv2, int64_t sv3,
|
||||
int64_t sb1, int64_t sb2, int64_t sb3,
|
||||
int64_t rq1, int64_t rq3,
|
||||
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
|
||||
int64_t sq1, int64_t sq2, int64_t sq3,
|
||||
int64_t sv1, int64_t sv2, int64_t sv3,
|
||||
int64_t sb1, int64_t sb2, int64_t sb3,
|
||||
int64_t neqk1, int64_t rq3,
|
||||
float scale, cudaStream_t stream) {
|
||||
//TODO: Add chunked kernel for even faster pre-fill
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const int num_warps = 4;
|
||||
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
|
||||
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
|
||||
|
||||
dim3 grid_dims(H, n_seqs, 1);
|
||||
dim3 block_dims(S_v, 1, 1);
|
||||
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||
const uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||
|
||||
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
switch (S_v) {
|
||||
case 32: {
|
||||
constexpr int sv = 32;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
case 16:
|
||||
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
break;
|
||||
case 32:
|
||||
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
break;
|
||||
}
|
||||
case 64: {
|
||||
constexpr int sv = 64;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
break;
|
||||
}
|
||||
case 128: {
|
||||
|
|
@ -172,7 +193,7 @@ static void launch_gated_delta_net(
|
|||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
@ -190,10 +211,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
ggml_tensor * src_state = dst->src[5];
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
|
||||
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
|
||||
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
|
||||
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
|
||||
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
|
||||
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
|
||||
|
||||
const int64_t S_v = nev0;
|
||||
const int64_t H = nev1;
|
||||
|
|
@ -202,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
|
||||
const bool kda = (src_g->ne[0] == S_v);
|
||||
|
||||
const int64_t rq1 = nev1 / neq1;
|
||||
GGML_ASSERT(neq1 == nek1);
|
||||
const int64_t neqk1 = neq1;
|
||||
|
||||
const int64_t rq3 = nev3 / neq3;
|
||||
|
||||
const float * q_d = (const float *) src_q->data;
|
||||
|
|
@ -241,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
if (kda) {
|
||||
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
// v is src[2], dimensions: S_v = ne[0], H = ne[1]
|
||||
const int ne20 = op->src[2]->ne[0]; // S_v
|
||||
const int ne21 = op->src[2]->ne[1]; // H
|
||||
const int ne30 = op->src[3]->ne[0]; // G
|
||||
|
||||
const int nsg = op->src[2]->ne[0]/32;
|
||||
|
||||
GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->ne[0] == ne20 * ne21);
|
||||
GGML_ASSERT(ne20 % 32 == 0);
|
||||
|
||||
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
|
||||
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
|
||||
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.nsg = nsg;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
|
|||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
|
|
|||
|
|
@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return op->src[2]->ne[0] % 32 == 0;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@
|
|||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
#define FC_UPSCALE 1500
|
||||
#define FC_GATED_DELTA_NET 1600
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
|
|
@ -793,6 +794,44 @@ typedef struct {
|
|||
uint64_t nb0;
|
||||
} ggml_metal_kargs_ssm_scan;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne20;
|
||||
int32_t ne21;
|
||||
int32_t ne22;
|
||||
int32_t ne23;
|
||||
uint64_t nb20;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ns02;
|
||||
int32_t ns12;
|
||||
int32_t ns22;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_gated_delta_net;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
|
|
|||
|
|
@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_rwkv(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
|
|
@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_kargs_gated_delta_net args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne20 =*/ ne20,
|
||||
/*.ne21 =*/ ne21,
|
||||
/*.ne22 =*/ ne22,
|
||||
/*.ne23 =*/ ne23,
|
||||
/*.nb20 =*/ nb20,
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
|
||||
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
|
||||
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
|
||||
|
||||
const int nsg = pipeline.nsg;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
|
|
|
|||
|
|
@ -2434,6 +2434,227 @@ kernel void kernel_rwkv_wkv7_f32(
|
|||
}
|
||||
}
|
||||
|
||||
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
||||
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
||||
|
||||
#if 1
|
||||
template<short NSG>
|
||||
kernel void kernel_gated_delta_net_impl(
|
||||
constant ggml_metal_kargs_gated_delta_net & args,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * g,
|
||||
device const char * b,
|
||||
device const char * s,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B
|
||||
const uint i21 = tgpig.y; // H
|
||||
const uint i20 = tgpig.x*NSG + ty;
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
|
||||
float ls[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] = s_ptr[is*S_v];
|
||||
}
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
||||
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
||||
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
||||
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
||||
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
float s_k = 0.0f;
|
||||
|
||||
if (G == 1) {
|
||||
const float g_exp = exp(g_ptr[0]);
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] *= g_exp;
|
||||
|
||||
s_k += ls[j]*k_ptr[is];
|
||||
}
|
||||
} else {
|
||||
// KDA
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] *= exp(g_ptr[is]);
|
||||
|
||||
s_k += ls[j]*k_ptr[is];
|
||||
}
|
||||
}
|
||||
|
||||
s_k = simd_sum(s_k);
|
||||
|
||||
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
||||
|
||||
float y = 0.0f;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] += k_ptr[is]*d;
|
||||
|
||||
y += ls[j]*q_ptr[is];
|
||||
}
|
||||
|
||||
y = simd_sum(y);
|
||||
|
||||
if (tx == 0) {
|
||||
dst_attn[t*args.ne21*S_v] = y*scale;
|
||||
}
|
||||
|
||||
q_ptr += args.ns02;
|
||||
k_ptr += args.ns12;
|
||||
v_ptr += args.ns22;
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is*S_v] = ls[j];
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
|
||||
|
||||
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
|
||||
|
||||
#else
|
||||
// a simplified version of the above
|
||||
// no performance improvement, so keep the above version for now
|
||||
|
||||
template<typename T, short NSG>
|
||||
kernel void kernel_gated_delta_net_impl(
|
||||
constant ggml_metal_kargs_gated_delta_net & args,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * g,
|
||||
device const char * b,
|
||||
device const char * s,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B
|
||||
const uint i21 = tgpig.y; // H
|
||||
const uint i20 = tgpig.x*NSG + ty;
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
|
||||
float lsf[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
lsf[j] = s_ptr[is*S_v];
|
||||
}
|
||||
|
||||
thread T * ls = (thread T *) (lsf);
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
||||
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
||||
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
||||
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
||||
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
device const T * qt_ptr = (device const T *) (q_ptr);
|
||||
device const T * kt_ptr = (device const T *) (k_ptr);
|
||||
device const T * gt_ptr = (device const T *) (g_ptr);
|
||||
|
||||
if (G == 1) {
|
||||
*ls *= exp(g_ptr[0]);
|
||||
} else {
|
||||
// KDA
|
||||
*ls *= exp(gt_ptr[tx]);
|
||||
}
|
||||
|
||||
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
|
||||
|
||||
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
||||
|
||||
*ls += kt_ptr[tx]*d;
|
||||
|
||||
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
|
||||
|
||||
if (tx == 0) {
|
||||
*dst_attn = y*scale;
|
||||
}
|
||||
|
||||
q_ptr += args.ns02;
|
||||
k_ptr += args.ns12;
|
||||
v_ptr += args.ns22;
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
dst_attn += args.ne21*S_v;
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
device T * dstt_state = (device T *) (dst_state);
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is*S_v] = lsf[j];
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
|
||||
|
||||
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
|
||||
#endif
|
||||
|
||||
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
|
||||
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
|
||||
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
|
||||
|
|
|
|||
|
|
@ -151,7 +151,8 @@ llama_context::llama_context(
|
|||
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
|
||||
|
||||
cparams.fused_gdn_ar = true;
|
||||
cparams.fused_gdn_ch = false; // TODO: implement
|
||||
cparams.fused_gdn_ch = true;
|
||||
cparams.auto_fgdn = true;
|
||||
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
|
|
@ -462,37 +463,81 @@ void llama_context::sched_reserve() {
|
|||
cparams.auto_fa = false;
|
||||
}
|
||||
|
||||
if (cparams.fused_gdn_ar) {
|
||||
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
|
||||
}
|
||||
if (cparams.auto_fgdn) {
|
||||
LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
|
||||
|
||||
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
|
||||
bool gdn_device_mismatch = false;
|
||||
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
||||
ggml_tensor * n = ggml_graph_node(gf, i);
|
||||
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
||||
continue;
|
||||
if (cparams.fused_gdn_ar) {
|
||||
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
|
||||
}
|
||||
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
||||
|
||||
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
|
||||
const int il = std::stoi(n->name + prefix_len);
|
||||
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
||||
if (device_gdn != device_kv) {
|
||||
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
||||
"is assigned to device %s (usually due to missing support)\n",
|
||||
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
||||
gdn_device_mismatch = true;
|
||||
break;
|
||||
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
|
||||
bool gdn_device_mismatch = false;
|
||||
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
||||
ggml_tensor * n = ggml_graph_node(gf, i);
|
||||
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
||||
continue;
|
||||
}
|
||||
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
||||
|
||||
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
|
||||
const int il = std::stoi(n->name + prefix_len);
|
||||
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
||||
if (device_gdn != device_kv) {
|
||||
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
||||
"is assigned to device %s (usually due to missing support)\n",
|
||||
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
||||
gdn_device_mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (gdn_device_mismatch) {
|
||||
cparams.fused_gdn_ar = false;
|
||||
LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
if (gdn_device_mismatch) {
|
||||
cparams.fused_gdn_ar = false;
|
||||
LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
|
||||
if (cparams.fused_gdn_ch) {
|
||||
// more than one token in the batch per sequence in order to take the chunked path
|
||||
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
|
||||
}
|
||||
|
||||
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
|
||||
bool gdn_device_mismatch = false;
|
||||
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
||||
ggml_tensor * n = ggml_graph_node(gf, i);
|
||||
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
||||
continue;
|
||||
}
|
||||
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
||||
|
||||
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
|
||||
const int il = std::stoi(n->name + prefix_len);
|
||||
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
||||
if (device_gdn != device_kv) {
|
||||
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
||||
"is assigned to device %s (usually due to missing support)\n",
|
||||
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
||||
gdn_device_mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (gdn_device_mismatch) {
|
||||
cparams.fused_gdn_ch = false;
|
||||
LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
cparams.auto_fgdn = false;
|
||||
}
|
||||
|
||||
// reserve worst-case graph
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ struct llama_cparams {
|
|||
bool auto_fa;
|
||||
bool fused_gdn_ar; // use fused gated delta net (autoregressive)
|
||||
bool fused_gdn_ch; // use fused gated delta net (chunked)
|
||||
bool auto_fgdn;
|
||||
bool no_perf;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
|
|
|
|||
|
|
@ -70,6 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t);
|
|||
|
||||
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
|
||||
|
||||
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
|
||||
#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__"
|
||||
#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__"
|
||||
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
|
||||
#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__"
|
||||
#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__"
|
||||
|
|
|
|||
|
|
@ -41,13 +41,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
if (cparams.fused_gdn_ch) {
|
||||
//ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
|
||||
//cb(result, LLAMA_TENSOR_NAME_FGDNCH, il);
|
||||
|
||||
GGML_ABORT("not implemented yet");
|
||||
}
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_k);
|
||||
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
|
|
@ -325,26 +318,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
if (cparams.fused_gdn_ar) {
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
|
||||
cb(result, LLAMA_TENSOR_NAME_FGDNAR, il);
|
||||
|
||||
ggml_tensor * output = ggml_view_4d(ctx0, result,
|
||||
S_v, H_v, n_tokens, n_seqs,
|
||||
ggml_row_size(result->type, S_v),
|
||||
ggml_row_size(result->type, S_v * H_v),
|
||||
ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
|
||||
|
||||
ggml_tensor * new_state = ggml_view_4d(ctx0, result,
|
||||
S_v, S_v, H_v, n_seqs,
|
||||
ggml_row_size(result->type, S_v),
|
||||
ggml_row_size(result->type, S_v * S_v),
|
||||
ggml_row_size(result->type, S_v * S_v * H_v),
|
||||
ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
|
||||
|
||||
return {output, new_state};
|
||||
}
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_k);
|
||||
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
|
|
@ -401,3 +374,78 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
|
||||
return {o, s};
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il) {
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
const int64_t n_tokens = q->ne[2];
|
||||
const int64_t n_seqs = q->ne[3];
|
||||
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(S_k == S_v);
|
||||
GGML_ASSERT(H_v % H_k == 0);
|
||||
|
||||
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
||||
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
|
||||
GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
|
||||
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
|
||||
if (n_tokens == 1) {
|
||||
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
|
||||
} else {
|
||||
cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
|
||||
}
|
||||
|
||||
ggml_tensor * output = ggml_view_4d(ctx0, result,
|
||||
S_v, H_v, n_tokens, n_seqs,
|
||||
ggml_row_size(result->type, S_v),
|
||||
ggml_row_size(result->type, S_v * H_v),
|
||||
ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
|
||||
|
||||
ggml_tensor * new_state = ggml_view_4d(ctx0, result,
|
||||
S_v, S_v, H_v, n_seqs,
|
||||
ggml_row_size(result->type, S_v),
|
||||
ggml_row_size(result->type, S_v * S_v),
|
||||
ggml_row_size(result->type, S_v * S_v * H_v),
|
||||
ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
|
||||
|
||||
return {output, new_state};
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il) {
|
||||
const int64_t n_seq_tokens = q->ne[2];
|
||||
|
||||
if (n_seq_tokens == 1) {
|
||||
if (cparams.fused_gdn_ar) {
|
||||
return build_delta_net_fused(q, k, v, g, b, s, il);
|
||||
}
|
||||
return build_delta_net_autoregressive(q, k, v, g, b, s, il);
|
||||
}
|
||||
|
||||
if (cparams.fused_gdn_ch) {
|
||||
return build_delta_net_fused(q, k, v, g, b, s, il);
|
||||
}
|
||||
|
||||
return build_delta_net_chunking(q, k, v, g, b, s, il);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -169,9 +169,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
|||
Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
|
||||
|
||||
// Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
|
||||
build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
|
||||
build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);
|
||||
auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il);
|
||||
|
||||
ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
|
|
|
|||
|
|
@ -44,6 +44,26 @@ struct llm_build_delta_net_base : public llm_graph_context {
|
|||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// use the ggml_gated_delta_net fused operator
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// choose one of two implementations above based on the number of tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
};
|
||||
|
||||
struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
|
|
|
|||
|
|
@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
|
|||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
if (num_k_heads != num_v_heads) {
|
||||
// note: need explicit repeat only if we are not using the fused GDN
|
||||
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
// TODO: try to avoid these explicit repeats by utilizing op broadcast
|
||||
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
|
@ -332,12 +332,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
|
|||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||
if (n_seq_tokens == 1) {
|
||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
}
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
|
|
|
|||
|
|
@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
|
|||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
if (num_k_heads != num_v_heads) {
|
||||
// note: need explicit repeat only if we are not using the fused GDN
|
||||
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
// TODO: try to avoid these explicit repeats by utilizing op broadcast
|
||||
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
|
@ -332,12 +332,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
|
|||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||
if (n_seq_tokens == 1) {
|
||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
}
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
|
|
|
|||
|
|
@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
|||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
// TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
|
||||
if (num_k_heads != num_v_heads) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
int64_t repeat_factor = num_v_heads / num_k_heads;
|
||||
|
|
@ -431,13 +432,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
|||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
|
||||
if (n_seq_tokens == 1) {
|
||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
}
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
|
|
|
|||
|
|
@ -8447,6 +8447,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
|
||||
|
|
@ -8456,10 +8459,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
// KDA (vector gate)
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
|
|
|
|||
Loading…
Reference in New Issue