diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/common.slang b/ggml/src/ggml-vulkan/vulkan-shaders/common.slang index d083944b20..11469857ff 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/common.slang +++ b/ggml/src/ggml-vulkan/vulkan-shaders/common.slang @@ -12,89 +12,68 @@ public uint WaveGetWaveIndex() { } } -public interface IReduceOp { - static T combine(T a, T b); +public interface ISharedMemory { + static vector get(uint idx); + static void set(uint idx, vector value); } -public struct MaxOp : IReduceOp { - static T combine(T a, T b) { return max(a, b); } -} -public struct SumOp : IReduceOp { - static T combine(T a, T b) { return a + b; } -} - -public interface ISharedMemory { - static T get(uint idx); - static void set(uint idx, T value); -} - -public T reduce, ShMem: ISharedMemory>(T value, uint from, uint to) { - const uint subgroup_id = WaveGetWaveIndex(); - const uint lane_id = WaveGetLaneIndex(); - const uint from_id = lane_id % from; - const uint subgroup_size = WaveGetLaneCount(); - - // Reduce with subgroup ops first - [unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) { - value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s)); - } - - if (to > subgroup_size) { - // Reduce inside workgroup with shmem - GroupMemoryBarrierWithGroupSync(); - if (lane_id < from) { - ShMem.set(subgroup_id * from + from_id, value); - } - GroupMemoryBarrierWithGroupSync(); - value = ShMem.get(from_id); - [unroll] for (uint s = 1; s < to / subgroup_size; ++s) { - value = Op::combine(value, ShMem.get(s * from + from_id)); - } - } - - return value; -} - -public interface IReduceVecOp { +public interface IReduceOp { static vector combine(vector a, vector b); } -public struct MaxVecOp : IReduceVecOp { +public struct MaxOp : IReduceOp { static vector combine(vector a, vector b) { return max(a, b); } } -public struct SumVecOp : IReduceOp> { +public struct SumOp : IReduceOp { static vector combine(vector a, vector b) { return a + b; } } -public vector reduceVec>, ShMem: ISharedMemory>>(vector value, uint from, uint to, bool OLD_AMD_WINDOWS = false) { - const uint subgroup_id = WaveGetWaveIndex(); - const uint lane_id = WaveGetLaneIndex(); - const uint from_id = lane_id % from; - const uint subgroup_size = WaveGetLaneCount(); +public vector reduce, ShMem: ISharedMemory>(vector value, uint from, uint to, uint tid, uint subgroup_size, bool OLD_AMD_WINDOWS = false) { + if (subgroup_size > 0) { + const uint subgroup_id = WaveGetWaveIndex(); + const uint lane_id = WaveGetLaneIndex(); + const uint from_id = lane_id % from; + const uint subgroup_size = WaveGetLaneCount(); - // Reduce with subgroup ops first - [unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) { - if (!OLD_AMD_WINDOWS) { - value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s)); - } else { - // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below. - // Shuffle full vec4 as workaround. - // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697 - value = Op::combine(value, vector(WaveReadLaneAt(value, lane_id ^ s))); + // Reduce with subgroup ops first + [unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) { + if (!OLD_AMD_WINDOWS) { + value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s)); + } else if (T is half) { + // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below. + // Shuffle full vec4 as workaround. + // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697 + value = Op::combine(value, (WaveReadLaneAt(vector((value as vector).value), lane_id ^ s) as vector).value); + } } - } - if (to > subgroup_size) { - // Reduce inside workgroup with shmem - GroupMemoryBarrierWithGroupSync(); - if (lane_id < from) { - ShMem.set(subgroup_id * from + from_id, value); + if (to > subgroup_size) { + // Reduce inside workgroup with shmem + GroupMemoryBarrierWithGroupSync(); + if (lane_id < from) { + ShMem.set(subgroup_id * from + from_id, value); + } + GroupMemoryBarrierWithGroupSync(); + value = ShMem.get(from_id); + [unroll] for (uint s = 1; s < to / subgroup_size; ++s) { + value = Op::combine(value, ShMem.get(s * from + from_id)); + } } + } else { + const uint group_id = tid / to; + const uint group_tid = tid % to; + const uint from_id = tid % from; + GroupMemoryBarrierWithGroupSync(); - value = ShMem.get(from_id); - [unroll] for (uint s = 1; s < to / subgroup_size; ++s) { - value = Op::combine(value, ShMem.get(s * from + from_id)); + ShMem.set(tid, value); + GroupMemoryBarrierWithGroupSync(); + [unroll] for (int s = int(to) / 2; s >= from; s >>= 1) { + if (group_tid < s) { + ShMem.set(tid, Op::combine(ShMem.get(tid), ShMem.get(tid ^ s))); + } + GroupMemoryBarrierWithGroupSync(); } + value = ShMem.get(group_id * to + from_id); } return value; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang index 30ab494bc0..205f6d5161 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang @@ -97,7 +97,7 @@ static const uint num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGro // If SubGroupSize is set to 0 then only use shmem reductions static const uint tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; groupshared float tmpsh[tmpsh_size]; -struct ShMemFloat: ISharedMemory { +struct ShMemFloat: ISharedMemory { static float get(uint idx) { return tmpsh[idx]; } @@ -106,7 +106,7 @@ struct ShMemFloat: ISharedMemory { } } groupshared vector tmpshv4[tmpsh_size]; -struct ShMemFloat4: ISharedMemory> { +struct ShMemFloat4: ISharedMemory { static vector get(uint idx) { return tmpshv4[idx]; } @@ -583,7 +583,7 @@ void main( float rowmaxf = Mf[r]; // Compute max across the row - rowmaxf = reduce, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup); + rowmaxf = reduce, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup, tid, SubGroupSize); float Moldf = Mf[r]; @@ -595,12 +595,12 @@ void main( Lf[r] = eMf*Lf[r]; // Compute sum across the row - Lf[r] = reduce, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup); + Lf[r] = reduce, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup, tid, SubGroupSize); [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = FLOAT(eMf) * Of[r][d]; - reduceVec, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, OLD_AMD_WINDOWS); + Of[r][d] = reduce, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, tid, SubGroupSize, OLD_AMD_WINDOWS); } }