unify scalar+vector and fix reduce function

This commit is contained in:
Ruben Ortlam 2026-03-13 09:23:03 +01:00
parent e880cb2e0d
commit 5ec6569eb5
2 changed files with 51 additions and 72 deletions

View File

@ -12,89 +12,68 @@ public uint WaveGetWaveIndex() {
}
}
public interface IReduceOp<T> {
static T combine(T a, T b);
public interface ISharedMemory<T, uint N> {
static vector<T, N> get(uint idx);
static void set(uint idx, vector<T, N> value);
}
// Reduction operator that combines two values by taking their maximum.
// Used as the `Op` parameter of `reduce` for row-max computations.
public struct MaxOp<T: IArithmetic> : IReduceOp<T> {
static T combine(T a, T b) { return max(a, b); }
}
// Reduction operator that combines two values by addition.
// Used as the `Op` parameter of `reduce` for row-sum computations.
public struct SumOp<T: IArithmetic> : IReduceOp<T> {
static T combine(T a, T b) { return a + b; }
}
// Abstraction over a groupshared scratch array, used by `reduce` to exchange
// partial results between subgroups. Implementations are expected to wrap a
// concrete `groupshared` array (see e.g. the ShMemFloat implementations
// elsewhere in this change).
public interface ISharedMemory<T> {
// Returns the element stored at index `idx`.
static T get(uint idx);
// Stores `value` at index `idx`.
static void set(uint idx, T value);
}
// Hierarchical reduction of `value` across the workgroup.
//
// Lanes are reduced in strides growing from `from` up to `to`
// (NOTE(review): both look like they must be powers of two, since strides
// advance by `s *= 2` and lane pairing uses `lane_id ^ s` — confirm against
// callers). The first phase uses subgroup shuffles; when `to` exceeds the
// subgroup size, a second phase combines per-subgroup partials through the
// `ShMem` groupshared accessor.
//
// Op    - pairwise reduction operator (e.g. MaxOp, SumOp).
// ShMem - shared-memory accessor for the cross-subgroup phase.
// Returns the reduced value.
//
// NOTE(review): the `to / subgroup_size` loop bound implies `to` is assumed
// to be a multiple of the subgroup size in the shared-memory phase — verify.
public T reduce<T: __BuiltinType, Op: IReduceOp<T>, ShMem: ISharedMemory<T>>(T value, uint from, uint to) {
const uint subgroup_id = WaveGetWaveIndex();
const uint lane_id = WaveGetLaneIndex();
// This lane's position within its `from`-sized group.
const uint from_id = lane_id % from;
const uint subgroup_size = WaveGetLaneCount();
// Reduce with subgroup ops first
// Butterfly (XOR) pattern: at each stride `s`, combine with the lane `s` away.
[unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
}
if (to > subgroup_size) {
// Reduce inside workgroup with shmem
GroupMemoryBarrierWithGroupSync();
// Only the first `from` lanes of each subgroup publish their partials,
// one slot per `from`-group per subgroup.
if (lane_id < from) {
ShMem.set(subgroup_id * from + from_id, value);
}
GroupMemoryBarrierWithGroupSync();
// Serially fold every subgroup's partial for this `from_id` slot.
// This overwrites the local partial with subgroup 0's slot first,
// then combines the remaining subgroups' slots.
value = ShMem.get(from_id);
[unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
value = Op::combine(value, ShMem.get(s * from + from_id));
}
}
return value;
}
public interface IReduceVecOp<T, uint N> {
public interface IReduceOp<T, uint N> {
static vector<T, N> combine(vector<T, N> a, vector<T, N> b);
}
public struct MaxVecOp<T: __BuiltinFloatingPointType, uint N> : IReduceVecOp<T, N> {
public struct MaxOp<T: __BuiltinFloatingPointType, uint N> : IReduceOp<T, N> {
static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return max(a, b); }
}
public struct SumVecOp<T: __BuiltinArithmeticType, uint N> : IReduceOp<vector<T, N>> {
public struct SumOp<T: __BuiltinArithmeticType, uint N> : IReduceOp<T, N> {
static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return a + b; }
}
public vector<T, N> reduceVec<T: __BuiltinType, uint N, Op: IReduceOp<vector<T, N>>, ShMem: ISharedMemory<vector<T, N>>>(vector<T, N> value, uint from, uint to, bool OLD_AMD_WINDOWS = false) {
const uint subgroup_id = WaveGetWaveIndex();
const uint lane_id = WaveGetLaneIndex();
const uint from_id = lane_id % from;
const uint subgroup_size = WaveGetLaneCount();
public vector<T, N> reduce<T: __BuiltinType, uint N, Op: IReduceOp<T, N>, ShMem: ISharedMemory<T, N>>(vector<T, N> value, uint from, uint to, uint tid, uint subgroup_size, bool OLD_AMD_WINDOWS = false) {
if (subgroup_size > 0) {
const uint subgroup_id = WaveGetWaveIndex();
const uint lane_id = WaveGetLaneIndex();
const uint from_id = lane_id % from;
const uint subgroup_size = WaveGetLaneCount();
// Reduce with subgroup ops first
[unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
if (!OLD_AMD_WINDOWS) {
value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
} else {
// Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
// Shuffle full vec4 as workaround.
// See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
value = Op::combine(value, vector<T, N>(WaveReadLaneAt(value, lane_id ^ s)));
// Reduce with subgroup ops first
[unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
if (!OLD_AMD_WINDOWS) {
value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
} else if (T is half) {
// Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
// Shuffle full vec4 as workaround.
// See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
value = Op::combine(value, (WaveReadLaneAt(vector<float, N>((value as vector<half, N>).value), lane_id ^ s) as vector<T, N>).value);
}
}
}
if (to > subgroup_size) {
// Reduce inside workgroup with shmem
GroupMemoryBarrierWithGroupSync();
if (lane_id < from) {
ShMem.set(subgroup_id * from + from_id, value);
if (to > subgroup_size) {
// Reduce inside workgroup with shmem
GroupMemoryBarrierWithGroupSync();
if (lane_id < from) {
ShMem.set(subgroup_id * from + from_id, value);
}
GroupMemoryBarrierWithGroupSync();
value = ShMem.get(from_id);
[unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
value = Op::combine(value, ShMem.get(s * from + from_id));
}
}
} else {
const uint group_id = tid / to;
const uint group_tid = tid % to;
const uint from_id = tid % from;
GroupMemoryBarrierWithGroupSync();
value = ShMem.get(from_id);
[unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
value = Op::combine(value, ShMem.get(s * from + from_id));
ShMem.set(tid, value);
GroupMemoryBarrierWithGroupSync();
[unroll] for (int s = int(to) / 2; s >= from; s >>= 1) {
if (group_tid < s) {
ShMem.set(tid, Op::combine(ShMem.get(tid), ShMem.get(tid ^ s)));
}
GroupMemoryBarrierWithGroupSync();
}
value = ShMem.get(group_id * to + from_id);
}
return value;

View File

@ -97,7 +97,7 @@ static const uint num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGro
// If SubGroupSize is set to 0 then only use shmem reductions
static const uint tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
groupshared float tmpsh[tmpsh_size];
struct ShMemFloat: ISharedMemory<float> {
struct ShMemFloat: ISharedMemory<float, 1> {
static float get(uint idx) {
return tmpsh[idx];
}
@ -106,7 +106,7 @@ struct ShMemFloat: ISharedMemory<float> {
}
}
groupshared vector<FLOAT, 4> tmpshv4[tmpsh_size];
struct ShMemFloat4: ISharedMemory<vector<FLOAT, 4>> {
struct ShMemFloat4: ISharedMemory<FLOAT, 4> {
static vector<FLOAT, 4> get(uint idx) {
return tmpshv4[idx];
}
@ -583,7 +583,7 @@ void main(
float rowmaxf = Mf[r];
// Compute max across the row
rowmaxf = reduce<float, MaxOp<float>, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup);
rowmaxf = reduce<float, 1, MaxOp<float, 1>, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup, tid, SubGroupSize);
float Moldf = Mf[r];
@ -595,12 +595,12 @@ void main(
Lf[r] = eMf*Lf[r];
// Compute sum across the row
Lf[r] = reduce<float, SumOp<float>, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup);
Lf[r] = reduce<float, 1, SumOp<float, 1>, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup, tid, SubGroupSize);
[unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = FLOAT(eMf) * Of[r][d];
reduceVec<FLOAT, 4, SumVecOp<FLOAT, 4>, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, OLD_AMD_WINDOWS);
Of[r][d] = reduce<FLOAT, 4, SumOp<FLOAT, 4>, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, tid, SubGroupSize, OLD_AMD_WINDOWS);
}
}