52 lines
1.2 KiB
Plaintext
52 lines
1.2 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : enable
|
|
|
|
#include "types.glsl"
|
|
|
|
// Push constants describing the 2-D view of data_a that main() scans.
// Member order and types are unchanged from the original declaration: the
// host-side push-constant data presumably mirrors this layout byte-for-byte,
// so do not reorder or retype members.
layout (push_constant) uniform parameter {
    uint32_t ne00, ne01;   // extents: ne00 = size of the inner index (i00), ne01 = size of the outer index (i01)
    uint32_t nb00, nb01;   // strides for i00 / i01; main() uses them directly as data_a indices (uint elements, not bytes)
    uint32_t a_offset;     // data_a index of the view's first element
} p;
|
|
|
|
// Invocations per workgroup; also the length of the shared reduction buffer.
// The halving reduction in main() needs this to be a power of two.
#define BLOCK_SIZE 256

// 1-D workgroup: local_size_y and local_size_z take their default of 1.
layout(local_size_x = BLOCK_SIZE) in;

// Input values, compared against the per-workgroup expert id in main().
layout (binding = 0) readonly buffer A { uint data_a[]; };
// Output: one count per workgroup, written at index gl_WorkGroupID.x in main().
layout (binding = 1) writeonly buffer D { uint data_d[]; };

// Per-invocation partial counts, combined by the reduction in main().
shared uint vals[BLOCK_SIZE];
|
|
|
|
// Counts how many entries of the ne00 x ne01 view of data_a equal this
// workgroup's expert id (gl_WorkGroupID.x), then stores that total in
// data_d[gl_WorkGroupID.x].
//
// NOTE(review): presumably the host dispatches one workgroup per expert, so
// data_d must hold at least as many entries as workgroups in X — confirm.
void main() {
    const uint expert  = gl_WorkGroupID.x;
    const uint lid     = gl_LocalInvocationID.x;
    const uint n_total = p.ne01 * p.ne00;

    // Phase 1: each invocation visits every BLOCK_SIZE-th flattened index,
    // starting from its own local index, and tallies matches privately.
    uint hits = 0u;
    uint pos = lid;
    while (pos < n_total) {
        // Unflatten into (row = outer index, col = inner index), then apply the
        // strides; they are in uint elements because they index data_a directly.
        const uint row = pos / p.ne00;
        const uint col = pos % p.ne00;
        const uint id  = data_a[p.a_offset + row * p.nb01 + col * p.nb00];

        hits += (id == expert) ? 1u : 0u;
        pos += BLOCK_SIZE;
    }

    // Phase 2: tree reduction in the shared array. Each step folds the upper
    // half of the active range into the lower half (BLOCK_SIZE must be a power
    // of two). barrier() is outside the `if` so every invocation reaches it.
    vals[lid] = hits;
    barrier();
    [[unroll]] for (uint stride = BLOCK_SIZE / 2; stride != 0u; stride /= 2u) {
        if (lid < stride) {
            vals[lid] += vals[lid + stride];
        }
        barrier();
    }

    // Only invocation 0 publishes the workgroup total; no barriers follow,
    // so the early return is safe.
    if (lid != 0u) {
        return;
    }
    data_d[expert] = vals[0];
}
|