vulkan: change gated_delta_net to shard a column across a subgroup (#20662)

* vulkan: change gated_delta_net to shard a column across a subgroup

This is based on https://github.com/ggml-org/llama.cpp/pull/20391, I used an
LLM to port the CUDA code to Vulkan, and guided to it to make various fixes to
work with Vulkan (e.g. handling different subgroup sizes, unknown mapping of
subgroup to invocation id, using subgroupAdd optionally, etc.).

This fixes a perf regression from the transposing of the values in memory
(!20443).

* vulkan: Spread columns across fewer lanes to reduce the number of workgroups
This commit is contained in:
Jeff Bolz 2026-03-20 06:17:15 -05:00 committed by GitHub
parent dc6592431b
commit e06c3ab2bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 140 additions and 67 deletions

View File

@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
};
const bool use_subgroup_reduce = device->subgroup_arithmetic;
for (uint32_t si = 0; si < 3; si++) {
const uint32_t S_V = gdn_sizes[si];
GGML_ASSERT(is_pow2(S_V));
uint32_t lanes_per_column;
if (S_V >= 128u && device->subgroup_clustered) {
lanes_per_column = 8u;
} else {
// Use largest power-of-two that divides both S_V and subgroup_size so that
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
// This means we don't need extra bounds checking logic in the shader.
lanes_per_column = std::min(S_V, device->subgroup_size);
}
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
size_t gdn_len;
const void * gdn_data;
if (use_subgroup_reduce && need_clustered_shader) {
gdn_len = gated_delta_net_f32_len;
gdn_data = (const void *)gated_delta_net_f32_data;
} else if (use_subgroup_reduce) {
gdn_len = gated_delta_net_f32_nocluster_len;
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
} else {
gdn_len = gated_delta_net_f32_shmem_len;
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
}
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
for (uint32_t kda = 0; kda < 2; kda++) {
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
}
}
}
@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
pc, { H, n_seqs, S_v });
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {

View File

@ -1,11 +1,25 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_CLUSTERED
#extension GL_KHR_shader_subgroup_clustered : enable
#endif
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
// so no bounds checking is needed.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint KDA = 0;
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
shared FLOAT_TYPE s_k[S_V];
shared FLOAT_TYPE s_q[S_V];
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
// This does a reduction across groups of LANES_PER_COLUMN
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
const uint lane = gl_SubgroupInvocationID;
temp[lane] = partial;
barrier();
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
FLOAT_TYPE other = temp[lane ^ s];
barrier();
temp[lane] += other;
barrier();
}
const FLOAT_TYPE result = temp[lane];
barrier();
return result;
}
#endif
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
switch (LANES_PER_COLUMN) {
case 1u:
return partial;
#if USE_SUBGROUP_CLUSTERED
// Workaround for GLSL requiring a literal constant for the cluster size.
// The branches should all fold away.
case 2u:
return subgroupClusteredAdd(partial, 2u);
case 4u:
return subgroupClusteredAdd(partial, 4u);
case 8u:
return subgroupClusteredAdd(partial, 8u);
case 16u:
return subgroupClusteredAdd(partial, 16u);
case 32u:
return subgroupClusteredAdd(partial, 32u);
case 64u:
return subgroupClusteredAdd(partial, 64u);
#endif
default:
#if USE_SUBGROUP_ADD
return subgroupAdd(partial);
#else
return reduce_add_shmem(partial);
#endif
}
}
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
@ -42,9 +103,9 @@ void main() {
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@ -53,76 +114,56 @@ void main() {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = q_off;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
if (KDA != 0) {
const uint g_base = gb_off * S_V;
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
}
barrier();
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
FLOAT_TYPE k_reg[ROWS_PER_LANE];
FLOAT_TYPE q_reg[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
}
FLOAT_TYPE g_exp[ROWS_PER_LANE];
if (KDA == 0) {
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
kv_col += dot(
vec4(state[i], state[i+1], state[i+2], state[i+3]),
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
);
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
g_exp[r] = g_val;
}
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = g_val * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
data_dst[attn_off + col] = attn_col * scale;
} else {
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
kv_col += dot(gv * sv, kv);
const uint g_base = gb_off * S_V;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
}
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = gv * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
FLOAT_TYPE kv_shard = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
}
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
FLOAT_TYPE attn_partial = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
if (lane == 0) {
data_dst[attn_off + col] = attn_col * scale;
}
attn_off += S_V * H;
barrier();
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
data_dst[s_off + state_base + col * S_V + i] = state[i];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}

View File

@ -987,7 +987,9 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));