196 lines
7.2 KiB
Plaintext
196 lines
7.2 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_shader_16bit_storage : require
|
|
#extension GL_EXT_nonuniform_qualifier : enable
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
#if ADD_RMS
|
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
|
#extension GL_KHR_shader_subgroup_basic : enable
|
|
#endif
|
|
|
|
#include "rte.glsl"
|
|
#include "types.glsl"
|
|
#include "utils.glsl"
|
|
|
|
// Push constants: destination shape, per-buffer strides, and the flag that
// turns on the partial-sum output in main().
layout (push_constant) uniform parameter2
{
    // Shape of the destination (four extents). Their product is the total
    // element count processed by main().
    uint ne20;
    uint ne21;
    uint ne22;
    uint ne23;

    // Strides for the sources and the destination. Row s is used for source s
    // (see src_idx) and row num_srcs for the destination (see dst_idx). Each
    // stride is multiplied by an index and the result is used directly as an
    // array index.
    uint nb[12][4];

    // Non-zero enables the sum-of-squares reduction in main(); only read when
    // the shader is compiled with ADD_RMS.
    uint rms_partials;
} p;
|
|
|
|
// Binding slots 0..11 are each declared three times, once per element type:
//   aN        - A_TYPE view, read by load_a
//   dN        - D_TYPE view, written by store_d
//   partialsN - float view, written by store_partial
//
// No readonly/writeonly decorations on any of them. This is a workaround for a
// MoltenVK bug, see https://github.com/ggml-org/llama.cpp/issues/15498

// A_TYPE views.
layout (binding = 0)  buffer A0  { A_TYPE data_a[]; } a0;
layout (binding = 1)  buffer A1  { A_TYPE data_a[]; } a1;
layout (binding = 2)  buffer A2  { A_TYPE data_a[]; } a2;
layout (binding = 3)  buffer A3  { A_TYPE data_a[]; } a3;
layout (binding = 4)  buffer A4  { A_TYPE data_a[]; } a4;
layout (binding = 5)  buffer A5  { A_TYPE data_a[]; } a5;
layout (binding = 6)  buffer A6  { A_TYPE data_a[]; } a6;
layout (binding = 7)  buffer A7  { A_TYPE data_a[]; } a7;
layout (binding = 8)  buffer A8  { A_TYPE data_a[]; } a8;
layout (binding = 9)  buffer A9  { A_TYPE data_a[]; } a9;
layout (binding = 10) buffer A10 { A_TYPE data_a[]; } a10;
layout (binding = 11) buffer A11 { A_TYPE data_a[]; } a11;

// D_TYPE views of the same slots.
layout (binding = 0)  buffer D0  { D_TYPE data_d[]; } d0;
layout (binding = 1)  buffer D1  { D_TYPE data_d[]; } d1;
layout (binding = 2)  buffer D2  { D_TYPE data_d[]; } d2;
layout (binding = 3)  buffer D3  { D_TYPE data_d[]; } d3;
layout (binding = 4)  buffer D4  { D_TYPE data_d[]; } d4;
layout (binding = 5)  buffer D5  { D_TYPE data_d[]; } d5;
layout (binding = 6)  buffer D6  { D_TYPE data_d[]; } d6;
layout (binding = 7)  buffer D7  { D_TYPE data_d[]; } d7;
layout (binding = 8)  buffer D8  { D_TYPE data_d[]; } d8;
layout (binding = 9)  buffer D9  { D_TYPE data_d[]; } d9;
layout (binding = 10) buffer D10 { D_TYPE data_d[]; } d10;
layout (binding = 11) buffer D11 { D_TYPE data_d[]; } d11;

// float views of the same slots, used for the partial sums.
layout (binding = 0,  std430) buffer PartialBuf0  { float partial_sums[]; } partials0;
layout (binding = 1,  std430) buffer PartialBuf1  { float partial_sums[]; } partials1;
layout (binding = 2,  std430) buffer PartialBuf2  { float partial_sums[]; } partials2;
layout (binding = 3,  std430) buffer PartialBuf3  { float partial_sums[]; } partials3;
layout (binding = 4,  std430) buffer PartialBuf4  { float partial_sums[]; } partials4;
layout (binding = 5,  std430) buffer PartialBuf5  { float partial_sums[]; } partials5;
layout (binding = 6,  std430) buffer PartialBuf6  { float partial_sums[]; } partials6;
layout (binding = 7,  std430) buffer PartialBuf7  { float partial_sums[]; } partials7;
layout (binding = 8,  std430) buffer PartialBuf8  { float partial_sums[]; } partials8;
layout (binding = 9,  std430) buffer PartialBuf9  { float partial_sums[]; } partials9;
layout (binding = 10, std430) buffer PartialBuf10 { float partial_sums[]; } partials10;
layout (binding = 11, std430) buffer PartialBuf11 { float partial_sums[]; } partials11;
|
|
|
|
// Specialization constant 0 (default 2): how many sources main() sums per
// element. It doubles as the binding / stride-row index of the destination
// (store_d(num_srcs, ...), p.nb[num_srcs]); num_srcs + 1 is the binding that
// receives the partial sums.
layout(constant_id = 0) const uint num_srcs = 2;
|
|
|
|
// Reads element i through the A_TYPE view of binding b and converts it to
// FLOAT_TYPE. A value of b outside 0..11 yields zero.
FLOAT_TYPE load_a(uint b, uint i) {
    FLOAT_TYPE value = FLOAT_TYPE(0);
    switch (b) {
        case 0:  value = FLOAT_TYPE(a0.data_a[i]);  break;
        case 1:  value = FLOAT_TYPE(a1.data_a[i]);  break;
        case 2:  value = FLOAT_TYPE(a2.data_a[i]);  break;
        case 3:  value = FLOAT_TYPE(a3.data_a[i]);  break;
        case 4:  value = FLOAT_TYPE(a4.data_a[i]);  break;
        case 5:  value = FLOAT_TYPE(a5.data_a[i]);  break;
        case 6:  value = FLOAT_TYPE(a6.data_a[i]);  break;
        case 7:  value = FLOAT_TYPE(a7.data_a[i]);  break;
        case 8:  value = FLOAT_TYPE(a8.data_a[i]);  break;
        case 9:  value = FLOAT_TYPE(a9.data_a[i]);  break;
        case 10: value = FLOAT_TYPE(a10.data_a[i]); break;
        case 11: value = FLOAT_TYPE(a11.data_a[i]); break;
        default: break;
    }
    return value;
}
|
|
|
|
// Converts v to D_TYPE and writes it to element i through the D_TYPE view of
// binding b. A value of b outside 0..11 writes nothing.
// The conversion stays inside each assignment instead of going through a
// D_TYPE local, so D_TYPE is only ever used as a buffer element type.
void store_d(uint b, uint i, FLOAT_TYPE v) {
    switch (b) {
        case 0:  d0.data_d[i]  = D_TYPE(v); return;
        case 1:  d1.data_d[i]  = D_TYPE(v); return;
        case 2:  d2.data_d[i]  = D_TYPE(v); return;
        case 3:  d3.data_d[i]  = D_TYPE(v); return;
        case 4:  d4.data_d[i]  = D_TYPE(v); return;
        case 5:  d5.data_d[i]  = D_TYPE(v); return;
        case 6:  d6.data_d[i]  = D_TYPE(v); return;
        case 7:  d7.data_d[i]  = D_TYPE(v); return;
        case 8:  d8.data_d[i]  = D_TYPE(v); return;
        case 9:  d9.data_d[i]  = D_TYPE(v); return;
        case 10: d10.data_d[i] = D_TYPE(v); return;
        case 11: d11.data_d[i] = D_TYPE(v); return;
        default: return;
    }
}
|
|
|
|
// Writes v to element i through the float view of binding b.
// A value of b outside 0..11 writes nothing.
void store_partial(uint b, uint i, float v) {
    switch (b) {
        case 0:
            partials0.partial_sums[i] = v;
            break;
        case 1:
            partials1.partial_sums[i] = v;
            break;
        case 2:
            partials2.partial_sums[i] = v;
            break;
        case 3:
            partials3.partial_sums[i] = v;
            break;
        case 4:
            partials4.partial_sums[i] = v;
            break;
        case 5:
            partials5.partial_sums[i] = v;
            break;
        case 6:
            partials6.partial_sums[i] = v;
            break;
        case 7:
            partials7.partial_sums[i] = v;
            break;
        case 8:
            partials8.partial_sums[i] = v;
            break;
        case 9:
            partials9.partial_sums[i] = v;
            break;
        case 10:
            partials10.partial_sums[i] = v;
            break;
        case 11:
            partials11.partial_sums[i] = v;
            break;
        default:
            break;
    }
}
|
|
|
|
// Flattens the 4-D coordinate (i00, i01, i02, i03) into an array index using
// stride row s of the push constants.
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
    // Unsigned arithmetic: the order of accumulation does not affect the result.
    uint offset = i00 * p.nb[s][0];
    offset += i01 * p.nb[s][1];
    offset += i02 * p.nb[s][2];
    offset += i03 * p.nb[s][3];
    return offset;
}
|
|
|
|
// Flattens the 4-D coordinate into a destination array index. The destination
// strides live in row num_srcs of p.nb, right after the source rows, so this
// is the same computation as src_idx with s = num_srcs.
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
    return src_idx(num_srcs, i00, i01, i02, i03);
}
|
|
|
|
// Linear element index of this invocation's first element, built from the 3-D
// global invocation ID: x is used as-is, each step in y advances 512 elements
// and each step in z advances 512 * 512 = 262144 elements.
uint get_idx() {
    const uint y_step = 512u;
    const uint z_step = 262144u; // 512 * 512
    const uvec3 gid = gl_GlobalInvocationID;
    return gid.z * z_step + gid.y * y_step + gid.x;
}
|
|
|
|
// Invocations per workgroup; the workgroup is one-dimensional along x.
const uint num_threads = 256;

layout (local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// Scratch space for the cross-subgroup reduction in main(), indexed by
// gl_SubgroupID.
// XXX TODO this could be sized based on the number of subgroups, but that's
// not considered a constant, so it is sized by num_threads instead.
shared FLOAT_TYPE sumsh[num_threads];
#endif
|
|
|
|
// Entry point.
//
// Each invocation handles up to num_iter destination elements, spaced
// num_threads apart. For every element it adds up the matching element of all
// num_srcs sources and stores the sum through binding num_srcs.
//
// When compiled with ADD_RMS and p.rms_partials is non-zero, the workgroup
// also reduces the squares of those sums and writes a single value per
// workgroup through binding num_srcs + 1.
void main() {
    uint idx = get_idx();
    // idx is advanced in the loop; the starting value is needed again to pick
    // the partial-sum slot.
    uint orig_idx = idx;

    uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    // Explicit constructor, consistent with `sum` below, so the initialisation
    // does not rely on an implicit int -> FLOAT_TYPE conversion.
    FLOAT_TYPE sum_sq = FLOAT_TYPE(0);

    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        if (idx >= ne) {
            // idx never decreases, so any remaining iteration is out of range too.
            continue;
        }
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);

        FLOAT_TYPE sum = FLOAT_TYPE(0);
        [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
            sum += load_a(s, src_idx(s, i00, i01, i02, i03));
        }
        sum_sq += sum*sum;
        store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);

        idx += num_threads;
    }

#if ADD_RMS
    // p.rms_partials is a push constant, so every invocation takes the same
    // branch and all of them reach the barrier() calls below.
    if (p.rms_partials != 0) {
        // reduce the sum within each subgroup, then across subgroups
        const uint NumSubgroups = num_threads / gl_SubgroupSize;
        sum_sq = subgroupAdd(sum_sq);
        if (gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] = sum_sq;
        }
        barrier();

        // Halving tree over the per-subgroup totals, carried out by invocation
        // 0 of each subgroup. The tree only visits every slot when
        // NumSubgroups is a power of two.
        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
                sum_sq += sumsh[gl_SubgroupID + s];
                sumsh[gl_SubgroupID] = sum_sq;
            }
            barrier();
        }

        // One value per workgroup: the slot is the invocation's starting index
        // divided by the num_iter * num_threads elements a workgroup covers.
        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
            // store_partial takes a float; convert explicitly from FLOAT_TYPE.
            store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), float(sum_sq));
        }
    }
#endif
}
|