From deee23863bb75a67471c62b605248d3ff725c9f0 Mon Sep 17 00:00:00 2001
From: ProgenyAlpha
Date: Thu, 12 Mar 2026 06:32:04 -0400
Subject: [PATCH] vulkan: add GATED_DELTA_NET op support (#20334)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* vulkan: add GATED_DELTA_NET op support

Implements the fused gated delta net recurrence as a Vulkan compute shader,
with full support for the scalar gate, the KDA vector gate, GQA broadcast,
multi-token sequences, and permuted (non-contiguous) q/k inputs.
Specialization constants select the head size (32/64/128) and KDA mode at
pipeline creation time.

Passes all 13 test-backend-ops cases on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6

* vulkan: optimize GATED_DELTA_NET shader (Phase 1)

- vec4 dot products on all inner loops (dp4 hardware intrinsic)
- Cache exp(g) in shared memory for the KDA path, eliminating ~32K redundant
  global reads and ~16K redundant exp() calls per token
- vec4 fused decay + rank-1 update (3 vec4 ops vs 12 scalar ops)
- Add perf benchmark cases for GATED_DELTA_NET to test-backend-ops

KDA TG: +5.4% throughput. Non-KDA: no regressions.
13/13 test-backend-ops cases passing on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6

* vulkan: address review feedback for GATED_DELTA_NET

Refactor the pipelines into a [3][2] array, use the A_TYPE/D_TYPE/FLOAT_TYPE
shader macros, pass scale via push constants, fix supports_op, and
restructure the dispatch logic.

Co-Authored-By: Claude Opus 4.6

* vulkan: use FLOAT_TYPE for buffer/shared declarations, align formatting

Co-Authored-By: Claude Opus 4.6

* vulkan: add explicit FLOAT_TYPE casts for buffer loads

Wrap data_q, data_k, and data_g buffer reads with FLOAT_TYPE() casts to
ensure correct behavior across all Vulkan configurations.

Co-Authored-By: Claude Opus 4.6

* vulkan: fix Q/K broadcast for interleaved head layout

Adapt to the interleaved broadcast convention from #20340:
head_id / rq1 → head_id % neq1

Co-Authored-By: Claude Opus 4.6

---------

Co-authored-by: Progeny Alpha
Co-authored-by: Claude Opus 4.6
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 119 ++++++++++++++++
 .../vulkan-shaders/gated_delta_net.comp       | 128 ++++++++++++++++++
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |   2 +
 tests/test-backend-ops.cpp                    |  17 +++
 4 files changed, 266 insertions(+)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 2a2f7f4f11..3c81805b84 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -825,6 +825,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
+    // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
+    vk_pipeline pipeline_gated_delta_net[3][2];
     vk_pipeline pipeline_ssm_scan_f32_d128;
     vk_pipeline pipeline_ssm_scan_f32_d256;
     vk_pipeline pipeline_ssm_conv_f32;
@@ -1454,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t C;
     uint32_t H;
 };
+struct vk_op_gated_delta_net_push_constants {
+    uint32_t H;
+    uint32_t n_tokens;
+    uint32_t n_seqs;
+    uint32_t s_off;
+    uint32_t sq1, sq2, sq3;
+    uint32_t sv1, sv2, sv3;
+    uint32_t sb1, sb2, sb3;
+    uint32_t neq1, rq3;
+    float scale;
+};
+
 struct vk_op_ssm_scan_push_constants {
     uint32_t nb02, nb03, nb12, nb13;
     uint32_t nb21, nb22, nb31;
@@ -4568,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    {
+        const uint32_t gdn_sizes[] = {32, 64, 128};
+        const char * gdn_names[][2] = {
+            {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
+            {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
+            {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
+        };
+        for (uint32_t si = 0; si < 3; si++) {
+            for (uint32_t kda = 0; kda < 2; kda++) {
+                ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
+                    gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
+                    "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
+                    {1, 1, 1}, {gdn_sizes[si], kda}, 1);
+            }
+        }
+    }
+
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -9498,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv7_f32;
         }
         return nullptr;
+    case GGML_OP_GATED_DELTA_NET:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            const uint32_t S_v = dst->src[2]->ne[0];
+            const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
+            uint32_t si;
+            switch (S_v) {
+                case 32: si = 0; break;
+                case 64: si = 1; break;
+                case 128: si = 2; break;
+                default: return nullptr;
+            }
+            return ctx->device->pipeline_gated_delta_net[si][kda];
+        }
+        return nullptr;
     case GGML_OP_SSM_SCAN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             const uint32_t d_state = src0->ne[0];
@@ -10328,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const ggml_tensor * src_q = dst->src[0];
+    const ggml_tensor * src_v = dst->src[2];
+    const ggml_tensor * src_beta = dst->src[4];
+
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    const uint32_t S_v = (uint32_t)src_v->ne[0];
+    const uint32_t H = (uint32_t)src_v->ne[1];
+    const uint32_t n_tokens = (uint32_t)src_v->ne[2];
+    const uint32_t n_seqs = (uint32_t)src_v->ne[3];
+
+    const uint32_t s_off = S_v * H * n_tokens * n_seqs;
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[6] = {};
+    for (int i = 0; i < 6; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+    }
+
+    const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
+    const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
+    const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
+    const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
+    const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
+    const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
+    const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
+    const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
+    const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
+
+    const uint32_t neq1 = (uint32_t)src_q->ne[1];
+    const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    const vk_op_gated_delta_net_push_constants pc = {
+        H, n_tokens, n_seqs, s_off,
+        sq1, sq2, sq3,
+        sv1, sv2, sv3,
+        sb1, sb2, sb3,
+        neq1, rq3,
+        scale
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+        pc, { H, n_seqs, 1u });
+}
+
 static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -13044,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
 
+    case GGML_OP_GATED_DELTA_NET:
+        ggml_vk_gated_delta_net(ctx, compute_ctx, node);
+
+        break;
+
     case GGML_OP_SSM_SCAN:
         ggml_vk_ssm_scan(ctx, compute_ctx, node);
 
@@ -15455,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true; // all inputs are contiguous, see ggml.c
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                const uint32_t S_v = op->src[2]->ne[0];
+                if (S_v != 32 && S_v != 64 && S_v != 128) {
+                    return false;
+                }
+                for (int i = 0; i < 6; i++) {
+                    if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
+                        return false;
+                    }
+                }
+                return op->type == GGML_TYPE_F32;
+            }
         case GGML_OP_SSM_SCAN:
             {
                 for (int i = 0; i < 6; i++) {
@@ -16332,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_RWKV_WKV7) {
         tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
             src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
+    } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+        tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
+            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
     } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
         src_clone[0]->flags = tensor->src[0]->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
new file mode 100644
index 0000000000..1fdf889e82
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
@@ -0,0 +1,128 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(constant_id = 0) const uint S_V = 128;
+layout(constant_id = 1) const uint KDA = 0;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint H;
+    uint n_tokens;
+    uint n_seqs;
+    uint s_off;
+    uint sq1, sq2, sq3;
+    uint sv1, sv2, sv3;
+    uint sb1, sb2, sb3;
+    uint neq1, rq3;
+    float scale;
+};
+
+layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
+layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; };
+layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; };
+layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; };
+layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
+layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
+layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
+
+shared FLOAT_TYPE s_k[S_V];
+shared FLOAT_TYPE s_q[S_V];
+shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
+
+void main() {
+    const uint head_id = gl_WorkGroupID.x;
+    const uint seq_id = gl_WorkGroupID.y;
+    const uint col = gl_LocalInvocationID.x;
+
+    const uint iq1 = head_id % neq1;
+    const uint iq3 = seq_id / rq3;
+
+    const uint state_size = S_V * S_V;
+    const uint state_base = (seq_id * H + head_id) * state_size;
+
+    FLOAT_TYPE state[S_V];
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+    }
+
+    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
+
+    for (uint t = 0; t < n_tokens; t++) {
+        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const uint k_off = q_off;
+        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
+
+        s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
+        s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
+
+        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
+
+        if (KDA != 0) {
+            const uint g_base = gb_off * S_V;
+            s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
+        }
+
+        barrier();
+
+        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
+        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
+
+        if (KDA == 0) {
+            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
+
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                kv_col += dot(
+                    vec4(state[i], state[i+1], state[i+2], state[i+3]),
+                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
+                );
+            }
+
+            FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = g_val * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        } else {
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                kv_col += dot(gv * sv, kv);
+            }
+
+            FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = gv * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        }
+
+        attn_off += S_V * H;
+        barrier();
+    }
+
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        data_dst[s_off + state_base + i * S_V + col] = state[i];
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index fb8941232b..4b00ba3deb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -987,6 +987,8 @@ void process_shaders() {
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
"opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a821655d10..e9f2e8ace4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8731,6 +8731,23 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2)); test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3)); + // GATED_DELTA_NET: realistic model configurations + // TG: n_seq_tokens=1 (autoregressive) + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); // Qwen3.5-like: 32 heads, d=128 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 1)); // smaller model + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1, 1, false, true)); // KDA + // PP: n_seq_tokens=64,256 (prompt processing) + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1)); // PP-64 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 256, 1)); // PP-256 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 512, 1)); // PP-512 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1024, 1)); // PP-1024 + // Small model configs (fewer heads = less GPU occupancy for autoregressive) + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1)); // 4h PP-64 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 256, 1)); // 4h PP-256 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 512, 1)); // 4h PP-512 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64 + return test_cases; }