vulkan: add GATED_DELTA_NET op support (#20334)
* vulkan: add GATED_DELTA_NET op support

  Implements the fused gated delta net recurrence as a Vulkan compute shader
  with full support for scalar gate, KDA vector gate, GQA broadcast,
  multi-token sequences, and permuted (non-contiguous) q/k inputs.
  Specialization constants select head size (32/64/128) and KDA mode at
  pipeline creation time.

  Passes all 13 test-backend-ops cases on AMD Radeon 890M (RADV GFX1150).

* vulkan: optimize GATED_DELTA_NET shader (Phase 1)

  - vec4 dot products on all inner loops (dp4 hardware intrinsic)
  - Cache exp(g) in shared memory for the KDA path, eliminating ~32K redundant
    global reads and ~16K redundant exp() calls per token
  - vec4 fused decay + rank-1 update (3 vec4 ops vs 12 scalar ops)
  - Add perf benchmark cases for GATED_DELTA_NET to test-backend-ops

  KDA TG: +5.4% throughput. Non-KDA: no regressions.
  13/13 test-backend-ops passing on AMD Radeon 890M (RADV GFX1150).

* vulkan: address review feedback for GATED_DELTA_NET

  Pipeline array refactor [3][2], A_TYPE/D_TYPE/FLOAT_TYPE shader macros,
  scale in push constants, supports_op fix, dispatch restructuring.

* vulkan: use FLOAT_TYPE for buffer/shared declarations, align formatting

* vulkan: add explicit FLOAT_TYPE casts for buffer loads

  Wrap data_q, data_k, and data_g buffer reads with FLOAT_TYPE() casts to
  ensure correct behavior across all Vulkan configurations.

* vulkan: fix Q/K broadcast for interleaved head layout

  Adapt to the interleaved broadcast convention from #20340:
  head_id / rq1 → head_id % neq1

---------

Co-authored-by: Progeny Alpha <ProgenyAlpha@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent c3e3f9e533
commit deee23863b
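For reference, the recurrence the new shader computes, reconstructed from the shader body further down (the matrix notation is an illustrative reading aid; the commit itself contains only the code). Per head, with state $S \in \mathbb{R}^{S_v \times S_v}$, per-token vectors $q, k, v \in \mathbb{R}^{S_v}$, gate $g$, and scalar $\beta$:

$$
\begin{aligned}
D &= e^{g} I \ \ \text{(scalar gate)} \qquad \text{or} \qquad D = \operatorname{diag}\!\left(e^{g_1}, \dots, e^{g_{S_v}}\right) \ \ \text{(KDA)} \\
\Delta &= \beta \left( v - (D S)^{\top} k \right) \\
S' &= D S + k\, \Delta^{\top} \\
o &= \mathrm{scale} \cdot {S'}^{\top} q, \qquad \mathrm{scale} = 1/\sqrt{S_v}
\end{aligned}
$$

Each workgroup walks one (sequence, head) pair token by token; each of its $S_v$ threads owns one column of $S$, which is why $q$, $k$, and $e^{g}$ are staged in shared memory while $v$, $\beta$, and the output are handled purely per column.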
@@ -825,6 +825,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
+    // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
+    vk_pipeline pipeline_gated_delta_net[3][2];
     vk_pipeline pipeline_ssm_scan_f32_d128;
     vk_pipeline pipeline_ssm_scan_f32_d256;
     vk_pipeline pipeline_ssm_conv_f32;
@@ -1454,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t C;
     uint32_t H;
 };
+struct vk_op_gated_delta_net_push_constants {
+    uint32_t H;
+    uint32_t n_tokens;
+    uint32_t n_seqs;
+    uint32_t s_off;
+    uint32_t sq1, sq2, sq3;
+    uint32_t sv1, sv2, sv3;
+    uint32_t sb1, sb2, sb3;
+    uint32_t neq1, rq3;
+    float scale;
+};
+
 struct vk_op_ssm_scan_push_constants {
     uint32_t nb02, nb03, nb12, nb13;
     uint32_t nb21, nb22, nb31;
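A note on the sq*/sv*/sb* fields of vk_op_gated_delta_net_push_constants above: the host converts ggml byte strides nb[i] into float-element strides before filling the struct, which is what lets permuted (non-contiguous) q/k run with no special path in the shader. A minimal sketch of that arithmetic, with assumed example shapes (not values from the patch):

#include <cstdint>
#include <cstdio>

// Sketch only: a contiguous f32 q of shape {S_v, H, n_tokens, n_seqs}.
// ggml stores byte strides nb[i]; the dispatcher divides by sizeof(float).
int main() {
    const uint64_t S_v = 128, H = 32, n_tokens = 8;
    const uint64_t nb1 = S_v * sizeof(float);   // bytes between heads
    const uint64_t nb2 = nb1 * H;               // bytes between tokens
    const uint64_t nb3 = nb2 * n_tokens;        // bytes between sequences
    printf("sq1=%llu sq2=%llu sq3=%llu\n",
           (unsigned long long)(nb1 / sizeof(float)),   // 128
           (unsigned long long)(nb2 / sizeof(float)),   // 4096
           (unsigned long long)(nb3 / sizeof(float)));  // 32768
    // A permuted q simply arrives with different nb[]; the shader indexes
    // exclusively through these element strides.
    return 0;
}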
@@ -4568,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    {
+        const uint32_t gdn_sizes[] = {32, 64, 128};
+        const char * gdn_names[][2] = {
+            {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
+            {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
+            {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
+        };
+        for (uint32_t si = 0; si < 3; si++) {
+            for (uint32_t kda = 0; kda < 2; kda++) {
+                ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
+                    gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
+                    "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
+                    {1, 1, 1}, {gdn_sizes[si], kda}, 1);
+            }
+        }
+    }
+
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -9498,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv7_f32;
         }
         return nullptr;
+    case GGML_OP_GATED_DELTA_NET:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            const uint32_t S_v = dst->src[2]->ne[0];
+            const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
+            uint32_t si;
+            switch (S_v) {
+                case 32:  si = 0; break;
+                case 64:  si = 1; break;
+                case 128: si = 2; break;
+                default:  return nullptr;
+            }
+            return ctx->device->pipeline_gated_delta_net[si][kda];
+        }
+        return nullptr;
     case GGML_OP_SSM_SCAN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             const uint32_t d_state = src0->ne[0];
@@ -10328,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const ggml_tensor * src_q = dst->src[0];
+    const ggml_tensor * src_v = dst->src[2];
+    const ggml_tensor * src_beta = dst->src[4];
+
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    const uint32_t S_v = (uint32_t)src_v->ne[0];
+    const uint32_t H = (uint32_t)src_v->ne[1];
+    const uint32_t n_tokens = (uint32_t)src_v->ne[2];
+    const uint32_t n_seqs = (uint32_t)src_v->ne[3];
+
+    const uint32_t s_off = S_v * H * n_tokens * n_seqs;
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[6] = {};
+    for (int i = 0; i < 6; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+    }
+
+    const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
+    const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
+    const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
+    const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
+    const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
+    const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
+    const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
+    const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
+    const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
+
+    const uint32_t neq1 = (uint32_t)src_q->ne[1];
+    const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    const vk_op_gated_delta_net_push_constants pc = {
+        H, n_tokens, n_seqs, s_off,
+        sq1, sq2, sq3,
+        sv1, sv2, sv3,
+        sb1, sb2, sb3,
+        neq1, rq3,
+        scale
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+        pc, { H, n_seqs, 1u });
+}
+
 static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
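On the dispatch geometry of ggml_vk_gated_delta_net above: the grid is { H, n_seqs, 1 } workgroups, and the shader binds local_size_x to specialization constant 0, so each workgroup runs S_v threads, one per state column. dst packs the attention output first, then the updated state starting at element offset s_off. A small worked example with assumed sizes (not values from the patch):

#include <cstdint>
#include <cstdio>

// Assumed example sizes: dst layout is
// [ attn: S_v*H*n_tokens*n_seqs floats ][ state: S_v*S_v*H*n_seqs floats ]
int main() {
    const uint32_t S_v = 128, H = 32, n_tokens = 64, n_seqs = 2;
    const uint32_t s_off = S_v * H * n_tokens * n_seqs; // 524288: start of the state region
    printf("attn region: %u floats, state region: %u floats at offset %u\n",
           s_off, S_v * S_v * H * n_seqs, s_off);
    // Token t of (seq, head) lands at (seq*n_tokens*H + head)*S_v + t*S_v*H + col,
    // matching the shader's attn_off initialization and its += S_V*H per token.
    return 0;
}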
@@ -13044,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
 
+    case GGML_OP_GATED_DELTA_NET:
+        ggml_vk_gated_delta_net(ctx, compute_ctx, node);
+
+        break;
+
     case GGML_OP_SSM_SCAN:
         ggml_vk_ssm_scan(ctx, compute_ctx, node);
 
@@ -15455,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true; // all inputs are contiguous, see ggml.c
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                const uint32_t S_v = op->src[2]->ne[0];
+                if (S_v != 32 && S_v != 64 && S_v != 128) {
+                    return false;
+                }
+                for (int i = 0; i < 6; i++) {
+                    if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
+                        return false;
+                    }
+                }
+                return op->type == GGML_TYPE_F32;
+            }
         case GGML_OP_SSM_SCAN:
             {
                 for (int i = 0; i < 6; i++) {
@@ -16332,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_RWKV_WKV7) {
         tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
             src_clone[4], src_clone[5], src_clone[6]);
+    } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+        tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
+            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
     } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
         src_clone[0]->flags = tensor->src[0]->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
gated_delta_net.comp (new file)
@@ -0,0 +1,128 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(constant_id = 0) const uint S_V = 128;
+layout(constant_id = 1) const uint KDA = 0;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint H;
+    uint n_tokens;
+    uint n_seqs;
+    uint s_off;
+    uint sq1, sq2, sq3;
+    uint sv1, sv2, sv3;
+    uint sb1, sb2, sb3;
+    uint neq1, rq3;
+    float scale;
+};
+
+layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
+layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; };
+layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; };
+layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; };
+layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
+layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
+layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
+
+shared FLOAT_TYPE s_k[S_V];
+shared FLOAT_TYPE s_q[S_V];
+shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
+
+void main() {
+    const uint head_id = gl_WorkGroupID.x;
+    const uint seq_id = gl_WorkGroupID.y;
+    const uint col = gl_LocalInvocationID.x;
+
+    const uint iq1 = head_id % neq1;
+    const uint iq3 = seq_id / rq3;
+
+    const uint state_size = S_V * S_V;
+    const uint state_base = (seq_id * H + head_id) * state_size;
+
+    FLOAT_TYPE state[S_V];
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+    }
+
+    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
+
+    for (uint t = 0; t < n_tokens; t++) {
+        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const uint k_off = q_off;
+        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
+
+        s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
+        s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
+
+        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
+
+        if (KDA != 0) {
+            const uint g_base = gb_off * S_V;
+            s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
+        }
+
+        barrier();
+
+        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
+        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
+
+        if (KDA == 0) {
+            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
+
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                kv_col += dot(
+                    vec4(state[i], state[i+1], state[i+2], state[i+3]),
+                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
+                );
+            }
+
+            FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = g_val * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        } else {
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                kv_col += dot(gv * sv, kv);
+            }
+
+            FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = gv * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        }
+
+        attn_off += S_V * H;
+        barrier();
+    }
+
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        data_dst[s_off + state_base + i * S_V + col] = state[i];
+    }
+}
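For intuition, the scalar-gate path of the shader above restated as plain CPU code: an illustrative sketch of the same arithmetic for one head and one token, not the ggml reference implementation.

#include <cmath>
#include <cstdint>
#include <vector>

// state is S_v x S_v, row-major: state[i*S_v + col] is the column slice that
// shader thread `col` keeps in registers. One iteration of the token loop.
static void gdn_token_ref(std::vector<float> & state,
                          const float * q, const float * k, const float * v,
                          float g, float beta, float scale,
                          float * out, uint32_t S_v) {
    const float gv = std::exp(g);
    for (uint32_t col = 0; col < S_v; col++) {              // one shader thread per col
        float kv = 0.0f;
        for (uint32_t i = 0; i < S_v; i++) {
            kv += state[i * S_v + col] * k[i];              // kv = (S^T k)[col]
        }
        const float delta = (v[col] - gv * kv) * beta;      // gated delta rule
        float attn = 0.0f;
        for (uint32_t i = 0; i < S_v; i++) {
            const float s = gv * state[i * S_v + col] + k[i] * delta; // S' = gS + k delta^T
            state[i * S_v + col] = s;
            attn += s * q[i];                               // o = scale * S'^T q
        }
        out[col] = attn * scale;
    }
}

The KDA path is the same arithmetic except the single exp(g) becomes a per-row exp(g[i]), applied inside both the kv accumulation and the state decay.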
@@ -987,6 +987,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
@@ -8731,6 +8731,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
     test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
 
+    // GATED_DELTA_NET: realistic model configurations
+    // TG: n_seq_tokens=1 (autoregressive)
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));    // Qwen3.5-like: 32 heads, d=128
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 1));     // smaller model
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1, 1, false, true)); // KDA
+    // PP: n_seq_tokens=64..1024 (prompt processing)
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1));   // PP-64
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 256, 1));  // PP-256
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 512, 1));  // PP-512
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1024, 1)); // PP-1024
+    // Small model configs (fewer heads = less GPU occupancy for autoregressive)
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1));    // 4h PP-64
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 256, 1));   // 4h PP-256
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 512, 1));   // 4h PP-512
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1));  // 4h PP-1024
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64
+
     return test_cases;
 }
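To exercise these cases, the usual invocation would be something like the following (assuming the existing test-backend-ops command-line interface, which this PR does not change):

./build/bin/test-backend-ops test -o GATED_DELTA_NET   # correctness (the 13 cases mentioned above)
./build/bin/test-backend-ops perf -o GATED_DELTA_NET   # the perf cases added here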