diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3e36435d16..566958b3a9 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
            {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
            {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
        };
+        const bool use_subgroup_reduce = device->subgroup_arithmetic;
        for (uint32_t si = 0; si < 3; si++) {
+            const uint32_t S_V = gdn_sizes[si];
+            GGML_ASSERT(is_pow2(S_V));
+
+            uint32_t lanes_per_column;
+            if (S_V >= 128u && device->subgroup_clustered) {
+                lanes_per_column = 8u;
+            } else {
+                // Use largest power-of-two that divides both S_V and subgroup_size so that
+                // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
+                // This means we don't need extra bounds checking logic in the shader.
+                lanes_per_column = std::min(S_V, device->subgroup_size);
+            }
+
+            const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
+            size_t gdn_len;
+            const void * gdn_data;
+            if (use_subgroup_reduce && need_clustered_shader) {
+                gdn_len = gated_delta_net_f32_len;
+                gdn_data = (const void *)gated_delta_net_f32_data;
+            } else if (use_subgroup_reduce) {
+                gdn_len = gated_delta_net_f32_nocluster_len;
+                gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
+            } else {
+                gdn_len = gated_delta_net_f32_shmem_len;
+                gdn_data = (const void *)gated_delta_net_f32_shmem_data;
+            }
+
+            const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
+            const std::array wg_denoms = {1u, 1u, cols_per_wg};
+
            for (uint32_t kda = 0; kda < 2; kda++) {
                ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
-                    gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
-                    "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
-                    {1, 1, 1}, {gdn_sizes[si], kda}, 1);
+                    gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
+                    wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
            }
        }
    }
@@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
-        pc, { H, n_seqs, 1u });
+        pc, { H, n_seqs, S_v });
 }
 
 static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
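A note on the selection logic above: since S_V and the subgroup size are both powers of two, std::min() of the two is also the largest power of two dividing both, which is what the lanes_per_column comment relies on. Below is a minimal standalone sketch (hypothetical S_V and subgroup-size values, not the backend's real vk_device fields) that checks the two invariants the shader depends on instead of bounds checks:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

static bool is_pow2(uint32_t x) { return x != 0 && (x & (x - 1)) == 0; }

int main() {
    const uint32_t gdn_sizes[3]      = {32u, 64u, 128u}; // hypothetical head sizes
    const uint32_t subgroup_sizes[3] = {16u, 32u, 64u};  // hypothetical device subgroup sizes
    const bool subgroup_clustered = true;                // assume clustered support

    for (uint32_t S_V : gdn_sizes) {
        assert(is_pow2(S_V));
        for (uint32_t subgroup_size : subgroup_sizes) {
            uint32_t lanes_per_column;
            if (S_V >= 128u && subgroup_clustered) {
                lanes_per_column = 8u; // fixed cluster size for the largest head
            } else {
                lanes_per_column = std::min(S_V, subgroup_size);
            }
            const uint32_t cols_per_wg = subgroup_size / lanes_per_column;
            assert(S_V % lanes_per_column == 0); // (1) each column splits evenly across its lanes
            assert(S_V % cols_per_wg == 0);      // (2) columns split evenly across workgroups
            printf("S_V=%3u sg=%2u -> lanes_per_column=%2u cols_per_wg=%u\n",
                   S_V, subgroup_size, lanes_per_column, cols_per_wg);
        }
    }
    return 0;
}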
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
index f008859b99..5e9f8308c1 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
@@ -1,11 +1,25 @@
 #version 450
 #extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
+#if USE_SUBGROUP_CLUSTERED
+#extension GL_KHR_shader_subgroup_clustered : enable
+#endif
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
+
+// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
+// so no bounds checking is needed.
 
 layout(constant_id = 0) const uint S_V = 128;
 layout(constant_id = 1) const uint KDA = 0;
+layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
+layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
 
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
+const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
+
+layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
 
 layout(push_constant) uniform Parameters {
     uint H;
@@ -27,14 +41,61 @@
 layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
 layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
 layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
 
-shared FLOAT_TYPE s_k[S_V];
-shared FLOAT_TYPE s_q[S_V];
-shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
+#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
+shared FLOAT_TYPE temp[SUBGROUP_SIZE];
+
+// Reduces a partial sum across each aligned group of LANES_PER_COLUMN lanes
+// using a shared-memory XOR butterfly.
+FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
+    const uint lane = gl_SubgroupInvocationID;
+    temp[lane] = partial;
+    barrier();
+    [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
+        FLOAT_TYPE other = temp[lane ^ s];
+        barrier();
+        temp[lane] += other;
+        barrier();
+    }
+    const FLOAT_TYPE result = temp[lane];
+    barrier();
+    return result;
+}
+#endif
+
+// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on the spec constant.
+FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
+    switch (LANES_PER_COLUMN) {
+    case 1u:
+        return partial;
+#if USE_SUBGROUP_CLUSTERED
+    // Workaround for GLSL requiring a literal constant for the cluster size.
+    // The branches should all fold away.
+    case 2u:
+        return subgroupClusteredAdd(partial, 2u);
+    case 4u:
+        return subgroupClusteredAdd(partial, 4u);
+    case 8u:
+        return subgroupClusteredAdd(partial, 8u);
+    case 16u:
+        return subgroupClusteredAdd(partial, 16u);
+    case 32u:
+        return subgroupClusteredAdd(partial, 32u);
+    case 64u:
+        return subgroupClusteredAdd(partial, 64u);
+#endif
+    default:
+#if USE_SUBGROUP_ADD
+        return subgroupAdd(partial);
+#else
+        return reduce_add_shmem(partial);
+#endif
+    }
+}
 
 void main() {
     const uint head_id = gl_WorkGroupID.x;
-    const uint seq_id = gl_WorkGroupID.y;
-    const uint col = gl_LocalInvocationID.x;
+    const uint seq_id = gl_WorkGroupID.y;
+    const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
+    const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
 
     const uint iq1 = head_id % neq1;
     const uint iq3 = seq_id / rq3;
@@ -42,9 +103,9 @@
     const uint state_size = S_V * S_V;
     const uint state_base = (seq_id * H + head_id) * state_size;
 
-    FLOAT_TYPE state[S_V];
-    [[unroll]] for (uint i = 0; i < S_V; i++) {
-        state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
+    FLOAT_TYPE s_shard[ROWS_PER_LANE];
+    [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+        s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
     }
 
     uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@@ -53,76 +114,56 @@
         const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
         const uint k_off = q_off;
         const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
-
-        s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
-        s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
         const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
-
-        if (KDA != 0) {
-            const uint g_base = gb_off * S_V;
-            s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
-        }
-
-        barrier();
-
-        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
         const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
 
+        FLOAT_TYPE k_reg[ROWS_PER_LANE];
+        FLOAT_TYPE q_reg[ROWS_PER_LANE];
+        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+            const uint i = r * LANES_PER_COLUMN + lane;
+            k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
+            q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
+        }
+
+        FLOAT_TYPE g_exp[ROWS_PER_LANE];
        if (KDA == 0) {
            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
-
-            FLOAT_TYPE kv_col = 0.0;
-            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
-                kv_col += dot(
-                    vec4(state[i], state[i+1], state[i+2], state[i+3]),
-                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
-                );
+            [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+                g_exp[r] = g_val;
            }
-
-            FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
-
-            FLOAT_TYPE attn_col = 0.0;
-            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
-                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
-                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
-                sv = g_val * sv + kv * delta_col;
-                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
-
-                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
-            }
-
-            data_dst[attn_off + col] = attn_col * scale;
        } else {
-            FLOAT_TYPE kv_col = 0.0;
-            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
-                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
-                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
-                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
-                kv_col += dot(gv * sv, kv);
+            const uint g_base = gb_off * S_V;
+            [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+                const uint i = r * LANES_PER_COLUMN + lane;
+                g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
            }
+        }
 
-            FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
+        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
 
-            FLOAT_TYPE attn_col = 0.0;
-            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
-                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
-                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
-                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
-                sv = gv * sv + kv * delta_col;
-                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+        FLOAT_TYPE kv_shard = 0.0;
+        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+            kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
+        }
+        FLOAT_TYPE kv_col = reduce_partial(kv_shard);
 
-                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
-            }
+        FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
 
+        FLOAT_TYPE attn_partial = 0.0;
+        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+            s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
+            attn_partial += s_shard[r] * q_reg[r];
+        }
+        FLOAT_TYPE attn_col = reduce_partial(attn_partial);
+
+        if (lane == 0) {
            data_dst[attn_off + col] = attn_col * scale;
+        }
 
        attn_off += S_V * H;
-        barrier();
    }
 
-    [[unroll]] for (uint i = 0; i < S_V; i++) {
-        data_dst[s_off + state_base + col * S_V + i] = state[i];
+    [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+        data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
    }
 }
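The three reduction paths compiled from this shader are meant to be interchangeable. Here is a minimal CPU sketch of the shared-memory fallback (hypothetical sizes; the two loops per stage stand in for the reads and writes separated by barrier() in reduce_add_shmem), showing that the XOR butterfly leaves every lane of an aligned LANES_PER_COLUMN cluster holding the cluster sum, i.e. the value subgroupClusteredAdd would return:

#include <cassert>
#include <cstdio>
#include <vector>

// Simulates reduce_add_shmem: each stage reads partner values first (the read
// before the barrier), then accumulates (the write after the barrier).
static std::vector<float> butterfly_reduce(std::vector<float> temp, unsigned lanes_per_column) {
    const unsigned n = (unsigned)temp.size();
    for (unsigned s = lanes_per_column / 2u; s > 0; s >>= 1u) {
        std::vector<float> other(n);
        for (unsigned lane = 0; lane < n; ++lane) other[lane] = temp[lane ^ s];
        for (unsigned lane = 0; lane < n; ++lane) temp[lane] += other[lane];
    }
    return temp;
}

int main() {
    const unsigned subgroup_size = 32u, lanes_per_column = 8u; // hypothetical
    std::vector<float> partial(subgroup_size);
    for (unsigned i = 0; i < subgroup_size; ++i) partial[i] = float(i + 1);

    const std::vector<float> reduced = butterfly_reduce(partial, lanes_per_column);
    for (unsigned col = 0; col < subgroup_size / lanes_per_column; ++col) {
        float expect = 0.0f;
        for (unsigned l = 0; l < lanes_per_column; ++l) expect += partial[col * lanes_per_column + l];
        for (unsigned l = 0; l < lanes_per_column; ++l) {
            assert(reduced[col * lanes_per_column + l] == expect); // whole cluster agrees
        }
    }
    printf("every lane holds its cluster's sum\n");
    return 0;
}

Note that XOR with s < LANES_PER_COLUMN never leaves a lane's aligned cluster, which is why the shared-memory path needs no masking.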
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index abd2a9c36f..8186dba36f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -987,7 +987,9 @@ void process_shaders() {
    string_to_spv("rwkv_wkv7_f32", "wkv7.comp",
                  merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
-    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
+    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
+    string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
+    string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
 
    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
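Taken together, the dispatch change ({ H, n_seqs, S_v } with wg_denoms = {1u, 1u, cols_per_wg}) and the shader's lane/column indexing partition the S_V x S_V state so that each element is read and written by exactly one lane. A standalone sketch with hypothetical sizes, assuming the element count is divided by wg_denoms to obtain the workgroup count:

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const unsigned S_V = 128u, subgroup_size = 32u, lanes_per_column = 8u; // hypothetical
    const unsigned cols_per_wg   = subgroup_size / lanes_per_column; // 4 columns per workgroup
    const unsigned rows_per_lane = S_V / lanes_per_column;           // 16 rows per lane
    const unsigned num_wg_z      = S_V / cols_per_wg;                // z workgroups after wg_denoms

    std::vector<int> writes(S_V * S_V, 0);
    for (unsigned wg_z = 0; wg_z < num_wg_z; ++wg_z) {
        for (unsigned inv = 0; inv < subgroup_size; ++inv) {         // gl_SubgroupInvocationID
            const unsigned lane = inv % lanes_per_column;
            const unsigned col  = wg_z * cols_per_wg + inv / lanes_per_column;
            for (unsigned r = 0; r < rows_per_lane; ++r) {
                writes[col * S_V + r * lanes_per_column + lane] += 1;
            }
        }
    }
    for (int w : writes) assert(w == 1); // full coverage, no overlap
    printf("%u workgroups x %u lanes cover the %ux%u state exactly once\n",
           num_wg_z, subgroup_size, S_V, S_V);
    return 0;
}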