From 2c623bfaeaa504deecce41c7f44b58bfbf92bc22 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam
Date: Fri, 6 Mar 2026 08:09:56 +0100
Subject: [PATCH] generic reductions

---
 .../ggml-vulkan/vulkan-shaders/common.slang   | 101 ++++++++++++++++
 .../vulkan-shaders/flash_attn.slang           | 110 ++++--------------
 2 files changed, 121 insertions(+), 90 deletions(-)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/common.slang

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/common.slang b/ggml/src/ggml-vulkan/vulkan-shaders/common.slang
new file mode 100644
index 0000000000..d083944b20
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/common.slang
@@ -0,0 +1,101 @@
+module common;
+
+[require(spirv, subgroup_basic)]
+public uint WaveGetWaveIndex() {
+    __target_switch
+    {
+    case spirv:
+        return spirv_asm {
+            OpCapability GroupNonUniform;
+            result:$$uint = OpLoad builtin(SubgroupId:uint);
+        };
+    }
+}
+
+public interface IReduceOp<T> {
+    static T combine(T a, T b);
+}
+
+public struct MaxOp<T> : IReduceOp<T> {
+    static T combine(T a, T b) { return max(a, b); }
+}
+public struct SumOp<T> : IReduceOp<T> {
+    static T combine(T a, T b) { return a + b; }
+}
+
+public interface ISharedMemory<T> {
+    static T get(uint idx);
+    static void set(uint idx, T value);
+}
+
+public T reduce<T, Op: IReduceOp<T>, ShMem: ISharedMemory<T>>(T value, uint from, uint to) {
+    const uint subgroup_id = WaveGetWaveIndex();
+    const uint lane_id = WaveGetLaneIndex();
+    const uint from_id = lane_id % from;
+    const uint subgroup_size = WaveGetLaneCount();
+
+    // Reduce with subgroup ops first
+    [unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
+        value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
+    }
+
+    if (to > subgroup_size) {
+        // Reduce inside workgroup with shmem
+        GroupMemoryBarrierWithGroupSync();
+        if (lane_id < from) {
+            ShMem.set(subgroup_id * from + from_id, value);
+        }
+        GroupMemoryBarrierWithGroupSync();
+        value = ShMem.get(from_id);
+        [unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
+            value = Op::combine(value, ShMem.get(s * from + from_id));
+        }
+    }
+
+    return value;
+}
+
+public interface IReduceVecOp<T, let N : int> {
+    static vector<T, N> combine(vector<T, N> a, vector<T, N> b);
+}
+
+public struct MaxVecOp<T, let N : int> : IReduceVecOp<T, N> {
+    static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return max(a, b); }
+}
+public struct SumVecOp<T, let N : int> : IReduceVecOp<T, N> {
+    static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return a + b; }
+}
+
+public vector<T, N> reduceVec<T, let N : int, Op: IReduceVecOp<T, N>, ShMem: ISharedMemory<vector<T, N>>>(vector<T, N> value, uint from, uint to, bool OLD_AMD_WINDOWS = false) {
+    const uint subgroup_id = WaveGetWaveIndex();
+    const uint lane_id = WaveGetLaneIndex();
+    const uint from_id = lane_id % from;
+    const uint subgroup_size = WaveGetLaneCount();
+
+    // Reduce with subgroup ops first
+    [unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
+        if (!OLD_AMD_WINDOWS) {
+            value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
+        } else {
+            // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
+            // Shuffle full vec4 as workaround.
+            // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
+            value = Op::combine(value, vector<T, N>(WaveReadLaneAt(value, lane_id ^ s)));
+        }
+    }
+
+    if (to > subgroup_size) {
+        // Reduce inside workgroup with shmem
+        GroupMemoryBarrierWithGroupSync();
+        if (lane_id < from) {
+            ShMem.set(subgroup_id * from + from_id, value);
+        }
+        GroupMemoryBarrierWithGroupSync();
+        value = ShMem.get(from_id);
+        [unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
+            value = Op::combine(value, ShMem.get(s * from + from_id));
+        }
+    }
+
+    return value;
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang
index cb3590fdf0..30ab494bc0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang
@@ -1,3 +1,4 @@
+import common;
 import types;
 import flash_attn_loader;
 
@@ -96,7 +97,23 @@
 static const uint num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
 // If SubGroupSize is set to 0 then only use shmem reductions
 static const uint tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
 groupshared float tmpsh[tmpsh_size];
+struct ShMemFloat: ISharedMemory<float> {
+    static float get(uint idx) {
+        return tmpsh[idx];
+    }
+    static void set(uint idx, float value) {
+        tmpsh[idx] = value;
+    }
+}
 groupshared vector<FLOAT, 4> tmpshv4[tmpsh_size];
+struct ShMemFloat4: ISharedMemory<vector<FLOAT, 4>> {
+    static vector<FLOAT, 4> get(uint idx) {
+        return tmpshv4[idx];
+    }
+    static void set(uint idx, vector<FLOAT, 4> value) {
+        tmpshv4[idx] = value;
+    }
+}
 static const uint masksh_stride = Br + 1;
 groupshared FLOAT masksh[Bc * masksh_stride];
@@ -566,34 +583,7 @@ void main(
             float rowmaxf = Mf[r];
 
             // Compute max across the row
-            if (SubGroupSize > 0) {
-                [unroll] for (uint s = D_split; s < SubGroupSize; s *= 2) {
-                    rowmaxf = max(rowmaxf, WaveReadLaneAt(rowmaxf, subgroup_invocation_id ^ s));
-                }
-                if (row_split == 1) {
-                    // Reduce inside workgroup with shmem
-                    GroupMemoryBarrierWithGroupSync();
-                    if (subgroup_invocation_id == d_tid) {
-                        tmpsh[subgroup_id * D_split + d_tid] = rowmaxf;
-                    }
-                    GroupMemoryBarrierWithGroupSync();
-                    rowmaxf = tmpsh[d_tid];
-                    [unroll] for (uint s = 1; s < num_subgroups; ++s) {
-                        rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
-                    }
-                }
-            } else {
-                GroupMemoryBarrierWithGroupSync();
-                tmpsh[tid] = rowmaxf;
-                GroupMemoryBarrierWithGroupSync();
-                [unroll] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
-                    if (rowgroup_tid < s) {
-                        tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
-                    }
-                    GroupMemoryBarrierWithGroupSync();
-                }
-                rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
-            }
+            rowmaxf = reduce<float, MaxOp<float>, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup);
 
             float Moldf = Mf[r];
 
@@ -605,72 +595,12 @@ void main(
             Lf[r] = eMf*Lf[r];
 
             // Compute sum across the row
-            if (SubGroupSize > 0) {
-                [unroll] for (uint s = D_split; s < SubGroupSize; s *= 2) {
-                    Lf[r] += WaveReadLaneAt(Lf[r], subgroup_invocation_id ^ s);
-                }
-                if (row_split == 1) {
-                    GroupMemoryBarrierWithGroupSync();
-                    if (subgroup_invocation_id == d_tid) {
-                        tmpsh[subgroup_id * D_split + d_tid] = Lf[r];
-                    }
-                    GroupMemoryBarrierWithGroupSync();
-                    Lf[r] = tmpsh[d_tid];
-                    [unroll] for (uint s = 1; s < num_subgroups; ++s) {
-                        Lf[r] += tmpsh[s * D_split + d_tid];
-                    }
-                }
-            } else {
-                GroupMemoryBarrierWithGroupSync();
-                tmpsh[tid] = Lf[r];
-                GroupMemoryBarrierWithGroupSync();
-                [unroll] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
-                    if (rowgroup_tid < s) {
-                        tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
-                    }
-                    GroupMemoryBarrierWithGroupSync();
-                }
-                Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
-            }
+            Lf[r] = reduce<float, SumOp<float>, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup);
 
             [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                 Of[r][d] = FLOAT(eMf) * Of[r][d];
-                if (SubGroupSize > 0) {
-                    [unroll] for (uint s = D_split; s < SubGroupSize; s *= 2) {
-                        if (!OLD_AMD_WINDOWS) {
-                            Of[r][d] += WaveReadLaneAt(Of[r][d], subgroup_invocation_id ^ s);
-                        } else {
-                            // Something about f16vector subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
-                            // Shuffle full vector as workaround.
-                            // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
-                            Of[r][d] += vector<FLOAT, 4>(WaveReadLaneAt(vector<float, 4>(Of[r][d]), subgroup_invocation_id ^ s));
-                        }
-                    }
-                    if (row_split == 1) {
-                        GroupMemoryBarrierWithGroupSync();
-                        if (subgroup_invocation_id == d_tid) {
-                            tmpshv4[subgroup_id * D_split + d_tid] = Of[r][d];
-                        }
-                        GroupMemoryBarrierWithGroupSync();
-                        Of[r][d] = tmpshv4[d_tid];
-                        [unroll] for (uint s = 1; s < num_subgroups; ++s) {
-                            Of[r][d] += tmpshv4[s * D_split + d_tid];
-                        }
-                    }
-                } else {
-                    GroupMemoryBarrierWithGroupSync();
-                    tmpshv4[tid] = Of[r][d];
-                    GroupMemoryBarrierWithGroupSync();
-                    [unroll] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
-                        if (rowgroup_tid < s) {
-                            Of[r][d] += tmpshv4[tid ^ s];
-                            tmpshv4[tid] = Of[r][d];
-                        }
-                        GroupMemoryBarrierWithGroupSync();
-                    }
-                    Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
-                }
+                Of[r][d] = reduceVec<FLOAT, 4, SumVecOp<FLOAT, 4>, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, OLD_AMD_WINDOWS);
             }
         }