generic reductions
This commit is contained in:
parent
e1b40fa53a
commit
2c623bfaea
|
|
@ -0,0 +1,101 @@
|
||||||
|
module common;
|
||||||
|
|
||||||
|
[require(spirv, subgroup_basic)]
|
||||||
|
public uint WaveGetWaveIndex() {
|
||||||
|
__target_switch
|
||||||
|
{
|
||||||
|
case spirv:
|
||||||
|
return spirv_asm {
|
||||||
|
OpCapability GroupNonUniform;
|
||||||
|
result:$$uint = OpLoad builtin(SubgroupId:uint);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface IReduceOp<T> {
|
||||||
|
static T combine(T a, T b);
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct MaxOp<T: IArithmetic> : IReduceOp<T> {
|
||||||
|
static T combine(T a, T b) { return max(a, b); }
|
||||||
|
}
|
||||||
|
public struct SumOp<T: IArithmetic> : IReduceOp<T> {
|
||||||
|
static T combine(T a, T b) { return a + b; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface ISharedMemory<T> {
|
||||||
|
static T get(uint idx);
|
||||||
|
static void set(uint idx, T value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public T reduce<T: __BuiltinType, Op: IReduceOp<T>, ShMem: ISharedMemory<T>>(T value, uint from, uint to) {
|
||||||
|
const uint subgroup_id = WaveGetWaveIndex();
|
||||||
|
const uint lane_id = WaveGetLaneIndex();
|
||||||
|
const uint from_id = lane_id % from;
|
||||||
|
const uint subgroup_size = WaveGetLaneCount();
|
||||||
|
|
||||||
|
// Reduce with subgroup ops first
|
||||||
|
[unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
|
||||||
|
value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (to > subgroup_size) {
|
||||||
|
// Reduce inside workgroup with shmem
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (lane_id < from) {
|
||||||
|
ShMem.set(subgroup_id * from + from_id, value);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
value = ShMem.get(from_id);
|
||||||
|
[unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
|
||||||
|
value = Op::combine(value, ShMem.get(s * from + from_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface IReduceVecOp<T, uint N> {
|
||||||
|
static vector<T, N> combine(vector<T, N> a, vector<T, N> b);
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct MaxVecOp<T: __BuiltinFloatingPointType, uint N> : IReduceVecOp<T, N> {
|
||||||
|
static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return max(a, b); }
|
||||||
|
}
|
||||||
|
public struct SumVecOp<T: __BuiltinArithmeticType, uint N> : IReduceOp<vector<T, N>> {
|
||||||
|
static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return a + b; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public vector<T, N> reduceVec<T: __BuiltinType, uint N, Op: IReduceOp<vector<T, N>>, ShMem: ISharedMemory<vector<T, N>>>(vector<T, N> value, uint from, uint to, bool OLD_AMD_WINDOWS = false) {
|
||||||
|
const uint subgroup_id = WaveGetWaveIndex();
|
||||||
|
const uint lane_id = WaveGetLaneIndex();
|
||||||
|
const uint from_id = lane_id % from;
|
||||||
|
const uint subgroup_size = WaveGetLaneCount();
|
||||||
|
|
||||||
|
// Reduce with subgroup ops first
|
||||||
|
[unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
|
||||||
|
if (!OLD_AMD_WINDOWS) {
|
||||||
|
value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
|
||||||
|
} else {
|
||||||
|
// Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
|
||||||
|
// Shuffle full vec4 as workaround.
|
||||||
|
// See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
|
||||||
|
value = Op::combine(value, vector<T, N>(WaveReadLaneAt(value, lane_id ^ s)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (to > subgroup_size) {
|
||||||
|
// Reduce inside workgroup with shmem
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (lane_id < from) {
|
||||||
|
ShMem.set(subgroup_id * from + from_id, value);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
value = ShMem.get(from_id);
|
||||||
|
[unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
|
||||||
|
value = Op::combine(value, ShMem.get(s * from + from_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import common;
|
||||||
import types;
|
import types;
|
||||||
import flash_attn_loader;
|
import flash_attn_loader;
|
||||||
|
|
||||||
|
|
@ -96,7 +97,23 @@ static const uint num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGro
|
||||||
// If SubGroupSize is set to 0 then only use shmem reductions
|
// If SubGroupSize is set to 0 then only use shmem reductions
|
||||||
static const uint tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
|
static const uint tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
|
||||||
groupshared float tmpsh[tmpsh_size];
|
groupshared float tmpsh[tmpsh_size];
|
||||||
|
struct ShMemFloat: ISharedMemory<float> {
|
||||||
|
static float get(uint idx) {
|
||||||
|
return tmpsh[idx];
|
||||||
|
}
|
||||||
|
static void set(uint idx, float value) {
|
||||||
|
tmpsh[idx] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
groupshared vector<FLOAT, 4> tmpshv4[tmpsh_size];
|
groupshared vector<FLOAT, 4> tmpshv4[tmpsh_size];
|
||||||
|
struct ShMemFloat4: ISharedMemory<vector<FLOAT, 4>> {
|
||||||
|
static vector<FLOAT, 4> get(uint idx) {
|
||||||
|
return tmpshv4[idx];
|
||||||
|
}
|
||||||
|
static void set(uint idx, vector<FLOAT, 4> value) {
|
||||||
|
tmpshv4[idx] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static const uint masksh_stride = Br + 1;
|
static const uint masksh_stride = Br + 1;
|
||||||
groupshared FLOAT masksh[Bc * masksh_stride];
|
groupshared FLOAT masksh[Bc * masksh_stride];
|
||||||
|
|
@ -566,34 +583,7 @@ void main(
|
||||||
float rowmaxf = Mf[r];
|
float rowmaxf = Mf[r];
|
||||||
|
|
||||||
// Compute max across the row
|
// Compute max across the row
|
||||||
if (SubGroupSize > 0) {
|
rowmaxf = reduce<float, MaxOp<float>, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup);
|
||||||
[unroll] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
|
||||||
rowmaxf = max(rowmaxf, WaveReadLaneAt(rowmaxf, subgroup_invocation_id ^ s));
|
|
||||||
}
|
|
||||||
if (row_split == 1) {
|
|
||||||
// Reduce inside workgroup with shmem
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
if (subgroup_invocation_id == d_tid) {
|
|
||||||
tmpsh[subgroup_id * D_split + d_tid] = rowmaxf;
|
|
||||||
}
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
rowmaxf = tmpsh[d_tid];
|
|
||||||
[unroll] for (uint s = 1; s < num_subgroups; ++s) {
|
|
||||||
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
tmpsh[tid] = rowmaxf;
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
[unroll] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
|
||||||
if (rowgroup_tid < s) {
|
|
||||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
|
|
||||||
}
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
}
|
|
||||||
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
|
||||||
}
|
|
||||||
|
|
||||||
float Moldf = Mf[r];
|
float Moldf = Mf[r];
|
||||||
|
|
||||||
|
|
@ -605,72 +595,12 @@ void main(
|
||||||
Lf[r] = eMf*Lf[r];
|
Lf[r] = eMf*Lf[r];
|
||||||
|
|
||||||
// Compute sum across the row
|
// Compute sum across the row
|
||||||
if (SubGroupSize > 0) {
|
Lf[r] = reduce<float, SumOp<float>, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup);
|
||||||
[unroll] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
|
||||||
Lf[r] += WaveReadLaneAt(Lf[r], subgroup_invocation_id ^ s);
|
|
||||||
}
|
|
||||||
if (row_split == 1) {
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
if (subgroup_invocation_id == d_tid) {
|
|
||||||
tmpsh[subgroup_id * D_split + d_tid] = Lf[r];
|
|
||||||
}
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
Lf[r] = tmpsh[d_tid];
|
|
||||||
[unroll] for (uint s = 1; s < num_subgroups; ++s) {
|
|
||||||
Lf[r] += tmpsh[s * D_split + d_tid];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
tmpsh[tid] = Lf[r];
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
[unroll] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
|
||||||
if (rowgroup_tid < s) {
|
|
||||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
|
|
||||||
}
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
}
|
|
||||||
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
|
||||||
}
|
|
||||||
|
|
||||||
[unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
|
[unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
Of[r][d] = FLOAT(eMf) * Of[r][d];
|
Of[r][d] = FLOAT(eMf) * Of[r][d];
|
||||||
|
|
||||||
if (SubGroupSize > 0) {
|
reduceVec<FLOAT, 4, SumVecOp<FLOAT, 4>, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, OLD_AMD_WINDOWS);
|
||||||
[unroll] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
|
||||||
if (!OLD_AMD_WINDOWS) {
|
|
||||||
Of[r][d] += WaveReadLaneAt(Of[r][d], subgroup_invocation_id ^ s);
|
|
||||||
} else {
|
|
||||||
// Something about f16vector<float, 4> subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
|
|
||||||
// Shuffle full vector<float, 4> as workaround.
|
|
||||||
// See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
|
|
||||||
Of[r][d] += vector<FLOAT, 4>(WaveReadLaneAt(vector<float, 4>(Of[r][d]), subgroup_invocation_id ^ s));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (row_split == 1) {
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
if (subgroup_invocation_id == d_tid) {
|
|
||||||
tmpshv4[subgroup_id * D_split + d_tid] = Of[r][d];
|
|
||||||
}
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
Of[r][d] = tmpshv4[d_tid];
|
|
||||||
[unroll] for (uint s = 1; s < num_subgroups; ++s) {
|
|
||||||
Of[r][d] += tmpshv4[s * D_split + d_tid];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
tmpshv4[tid] = Of[r][d];
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
[unroll] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
|
||||||
if (rowgroup_tid < s) {
|
|
||||||
Of[r][d] += tmpshv4[tid ^ s];
|
|
||||||
tmpshv4[tid] = Of[r][d];
|
|
||||||
}
|
|
||||||
GroupMemoryBarrierWithGroupSync();
|
|
||||||
}
|
|
||||||
Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue