From c5a778891ba0ddbd4cbb507c823f970595b1adc2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 7 Mar 2026 15:41:10 +0800 Subject: [PATCH] ggml: add GATED_DELTA_NET op (#19504) * ggml: add GATED_DELTA_NET op * remove the transpose * add KDA * add qwen35 dense * llama : check for fused gated delta net backend support --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 10 ++ ggml/src/ggml-cpu/ggml-cpu.c | 10 ++ ggml/src/ggml-cpu/ops.cpp | 184 ++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml-cuda/gated_delta_net.cu | 223 +++++++++++++++++++++++++ ggml/src/ggml-cuda/gated_delta_net.cuh | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 5 + ggml/src/ggml.c | 57 ++++++- src/llama-context.cpp | 42 ++++- src/llama-cparams.h | 2 + src/llama-impl.h | 4 +- src/models/delta-net-base.cpp | 29 ++++ src/models/qwen35.cpp | 3 +- src/models/qwen35moe.cpp | 3 +- tests/test-backend-ops.cpp | 60 +++++++ 15 files changed, 627 insertions(+), 10 deletions(-) create mode 100644 ggml/src/ggml-cuda/gated_delta_net.cu create mode 100644 ggml/src/ggml-cuda/gated_delta_net.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 784d69206b..566e271479 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -556,6 +556,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_GATED_DELTA_NET, GGML_OP_UNARY, @@ -2463,6 +2464,15 @@ extern "C" { bool lower, bool uni); + GGML_API struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7c4026fac4..dc2b5ffaa7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_GATED_DELTA_NET: + { + ggml_compute_forward_gated_delta_net(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_COUNT_EQUAL: case GGML_OP_SOLVE_TRI: + case GGML_OP_GATED_DELTA_NET: { n_tasks = n_threads; } break; @@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan( { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; + case GGML_OP_GATED_DELTA_NET: + { + const int64_t S_v = node->src[2]->ne[0]; + cur = S_v * sizeof(float) * n_tasks; + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2c372f9635..331e071a26 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10380,6 +10380,190 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s } } +// ggml_compute_forward_gated_delta_net +static void ggml_compute_forward_gated_delta_net_one_chunk( + const ggml_compute_params * params, + ggml_tensor * dst, + int64_t ir0, + int64_t ir1) { + + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; + const int64_t n_tokens = src_v->ne[2]; + const int64_t n_seqs = src_v->ne[3]; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v); + GGML_ASSERT(src_beta->ne[0] == 1); + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne); + GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const bool kda = (neg0 == S_v); + + // scratch layout per thread: [delta(S_v)] + const int64_t scratch_per_thread = S_v; + const int ith = params->ith; + + float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + + // output layout: [attn_scores | new_states] + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs floats + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_out_base = (float *)dst->data; + float * state_out_base = (float *)dst->data + attn_score_elems; + + const float * state_in_base = (const float *)src_state->data; + + const int64_t rq1 = nev1 / neq1; + const int64_t rk1 = nev1 / nek1; + const int64_t rq3 = nev3 / neq3; + const int64_t rk3 = nev3 / nek3; + + const float scale = 1.0f / sqrtf((float) S_v); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t iv1 = ir % H; // head_index + const int64_t iv3 = ir / H; // sequence + + const int64_t iq1 = iv1 / rq1; + const int64_t ik1 = iv1 / rk1; + + const int64_t iq3 = iv3 / rq3; + const int64_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + + // copy input state into output buffer and operate in-place + const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + + // attn output pointer for first token of this (head, seq) + float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; + + for (int64_t t = 0; t < n_tokens; t++) { + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); + const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); + const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); + + const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + + if (kda) { + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i])); + } + } else { + ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); + } + + // delta[j] = sum_i S[j][i] * k[i] + memset(delta, 0, S_v * sizeof(float)); + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]); + } + for (int64_t j = 0; j < S_v; ++j) { + delta[j] = (v_d[j] - delta[j]) * beta_val; + } + + // outer product: S[j][i] += k[i] * delta[j] + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]); + } + + // attn_out[j] = sum_i S[j][i] * q[i] + memset(attn_data, 0, S_v * sizeof(float)); + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]); + } + ggml_vec_scale_f32(S_v, attn_data, scale); + + attn_data += S_v * H; // advance to next token + } + + } +} + + +static void ggml_compute_forward_gated_delta_net_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + ggml_tensor * V = dst->src[2]; + int64_t nr = V->ne[1] * V->ne[3]; + + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + int nth = params->nth; + int ith = params->ith; + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1); + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } +} + +void ggml_compute_forward_gated_delta_net( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gated_delta_net_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_rwkv_wkv7 static void ggml_compute_forward_rwkv_wkv7_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee7976..3fa1443abc 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu new file mode 100644 index 0000000000..d8e8111455 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -0,0 +1,223 @@ +#include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" + +template +__global__ void gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + int64_t rq1, + int64_t rq3, + float scale) { + const int64_t h_idx = blockIdx.x; + const int64_t sequence = blockIdx.y; + const int col = threadIdx.x; // each thread owns one column + + const int64_t iq1 = h_idx / rq1; + const int64_t iq3 = sequence / rq3; + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; + state += state_offset; + curr_state += state_offset; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + // Load state column into registers + float s[S_v]; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = curr_state[i * S_v + col]; + } + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + if constexpr (!KDA) { + const float g_val = expf(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + kv_col += s[i] * k_t[i]; + } + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = g_val * s[i] + k_t[i] * delta_col; + attn_col += s[i] * q_t[i]; + } + + attn_data[col] = attn_col * scale; + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + kv_col += expf(g_t[i]) * s[i] * k_t[i]; + } + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col; + attn_col += s[i] * q_t[i]; + } + + attn_data[col] = attn_col * scale; + } + + attn_data += S_v * H; + } + + // Write state back to global memory +#pragma unroll + for (int i = 0; i < S_v; i++) { + state[i * S_v + col] = s[i]; + } +} + +template +static void launch_gated_delta_net( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t rq1, int64_t rq3, + float scale, cudaStream_t stream) { + + dim3 grid_dims(H, n_seqs, 1); + dim3 block_dims(S_v, 1, 1); + + switch (S_v) { + case 32: + gated_delta_net_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + case 64: + gated_delta_net_cuda<64, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + case 128: + gated_delta_net_cuda<128, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + const int64_t rq1 = nev1 / neq1; + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + cudaStream_t stream = ctx.stream(); + + if (kda) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale, stream); + } +} diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh new file mode 100644 index 0000000000..7375e81c0c --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" +#include "ggml.h" + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 54dc43bc08..a8007a0636 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -53,6 +53,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/gated_delta_net.cuh" #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" @@ -2733,6 +2734,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_cuda_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_cuda_op_gated_delta_net(ctx, dst); + break; case GGML_OP_RWKV_WKV7: ggml_cuda_op_rwkv_wkv7(ctx, dst); break; @@ -4972,6 +4976,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_GATED_DELTA_NET: case GGML_OP_RWKV_WKV7: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d644cca8a6..aeafc395d7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "GATED_DELTA_NET", "UNARY", @@ -1048,7 +1049,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1140,6 +1141,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "gated_delta_net(q, k, v, g, beta, s)", "unary(x)", @@ -1157,7 +1159,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6124,6 +6126,57 @@ struct ggml_tensor * ggml_solve_tri( return result; } +// ggml_gated_delta_net + +struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous_rows(q)); + GGML_ASSERT(ggml_is_contiguous_rows(k)); + GGML_ASSERT(ggml_is_contiguous_rows(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA) + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT(beta->ne[0] == 1); + + GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); + + // concat output and new_state into a single tensor + // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_GATED_DELTA_NET; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = g; + result->src[4] = beta; + result->src[5] = state; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d050604b2a..abaa5c0f8d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -150,6 +150,9 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; + cparams.fused_gdn_ar = true; + cparams.fused_gdn_ch = false; // TODO: implement + // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -422,7 +425,7 @@ void llama_context::sched_reserve() { if (cparams.auto_fa) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); + throw std::runtime_error("failed to reserve graph for Flash Attention check"); } const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; @@ -432,8 +435,7 @@ void llama_context::sched_reserve() { if (n->op != GGML_OP_FLASH_ATTN_EXT) { continue; } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); + ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); @@ -448,6 +450,7 @@ void llama_context::sched_reserve() { break; } } + if (fa_device_mismatch) { cparams.flash_attn = false; LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); @@ -459,6 +462,39 @@ void llama_context::sched_reserve() { cparams.auto_fa = false; } + if (cparams.fused_gdn_ar) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ar = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__); + } + } + // reserve worst-case graph int n_splits_pp = -1; int n_nodes_pp = -1; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2da3bbd6f9..333922468c 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -31,6 +31,8 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool auto_fa; + bool fused_gdn_ar; // use fused gated delta net (autoregressive) + bool fused_gdn_ch; // use fused gated delta net (chunked) bool no_perf; bool warmup; bool op_offload; diff --git a/src/llama-impl.h b/src/llama-impl.h index dfd9fee9f4..ee27ac1bea 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -70,4 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); -#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__" +#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__" diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index c57abbb5b7..b0be62fc68 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-impl.h" + // utility to get one slice from the third dimension // input dim: [x, y, c, b] // output dim: [x, y, 1, b] @@ -39,6 +41,13 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + if (cparams.fused_gdn_ch) { + //ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + //cb(result, LLAMA_TENSOR_NAME_FGDNCH, il); + + GGML_ABORT("not implemented yet"); + } + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); @@ -316,6 +325,26 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + if (cparams.fused_gdn_ar) { + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + cb(result, LLAMA_TENSOR_NAME_FGDNAR, il); + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; + } + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index bacf7a4c2e..afc5a1aad7 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -332,8 +332,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) + std::pair attn_out; if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 22d708f206..17291ec230 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -332,8 +332,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) + std::pair attn_out; if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index faa771e086..32a83b001d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3665,6 +3665,51 @@ struct test_rwkv_wkv6 : public test_case { } }; +// GGML_OP_GATED_DELTA_NET +struct test_gated_delta_net : public test_case { + const ggml_type type; + + const int64_t head_count; + const int64_t head_size; + const int64_t n_seq_tokens; + const int64_t n_seqs; + const int v_repeat; + const bool permuted; + const bool kda; + + std::string vars() override { + return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda); + } + + test_gated_delta_net(ggml_type type = GGML_TYPE_F32, + int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, + int v_repeat = 1, bool permuted = false, bool kda = false) + : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), + v_repeat(v_repeat), permuted(permuted), kda(kda) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q; + ggml_tensor * k; + ggml_tensor * v; + if (permuted) { + // create with dims 1 and 2 swapped, then permute back to get non-contiguous layout + q = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3); + k = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3); + v = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count * v_repeat, n_seqs), 0, 2, 1, 3); + } else { + q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs); + k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs); + v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs); + } + const int64_t g_ne0 = kda ? head_size : 1; + ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs); + ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs); + ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs); + ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state); + return out; + } +}; + // GGML_OP_GATED_LINEAR_ATTN struct test_gla : public test_case { const ggml_type type; @@ -8405,6 +8450,21 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1, 1, true)); + // KDA (vector gate) + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true)); + #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true));