140 lines
4.1 KiB
Plaintext
140 lines
4.1 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
#extension GL_KHR_shader_subgroup_basic : enable
|
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
|
#extension GL_KHR_shader_subgroup_shuffle : enable
|
|
|
|
#include "types.glsl"
|
|
|
|
// Per-dispatch parameters supplied through push constants.
layout (push_constant) uniform parameter {
    // Number of logit rows to process; invocations mapped to a row index
    // at or beyond this value exit without writing anything.
    uint n_rows;

    // Number of experts selected per row (the "k" of the top-k search).
    // Also the row stride of the weights output buffer.
    uint n_expert_used;
};
|
|
|
|
// Workgroup shape: the x extent is specialised through constant 0, and four
// rows (the y extent) are processed per workgroup.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

// Specialisation constants.
//   WARP_SIZE - uses the same constant id as local_size_x, so it always
//               matches the workgroup's x extent.
//   n_experts - number of experts, i.e. the number of logits per row.
//   with_norm - when true, the selected weights are divided by their sum
//               before being written out.
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
|
|
|
|
// Number of expert slots each invocation keeps in its private arrays.
// The division rounds up: when n_experts is larger than WARP_SIZE but not a
// multiple of it, the final partial stride still needs a slot, otherwise the
// strided load loop in main() would index past the end of the arrays. The
// per-expert bounds checks in main() mask out the unused tail of that stride.
// For n_experts <= WARP_SIZE a single slot per invocation is enough.
const uint experts_per_thread = (n_experts > WARP_SIZE) ? (n_experts + WARP_SIZE - 1) / WARP_SIZE : 1;
|
|
|
|
// Input: router logits, n_experts consecutive floats per row.
layout (binding = 0, std430) readonly buffer Logits {
    float logits[];
};

// Output: weights of the selected experts, n_expert_used floats per row.
layout (binding = 1, std430) writeonly buffer Weights {
    float weights[];
};

// Output: indices of the selected experts. Rows are addressed with a stride
// of n_experts, of which only the first n_expert_used entries are written.
layout (binding = 2, std430) writeonly buffer Ids {
    uint ids[];
};
|
|
|
|
// Fused softmax + top-k expert selection for one row of router logits.
//
// Each y-slice of the workgroup handles one row. The invocations along x
// split the row's experts between them in a strided pattern: invocation x
// owns experts x, x + WARP_SIZE, x + 2 * WARP_SIZE, ...
//
// Outputs per row:
//   ids[n_experts * row + k]         - index of the k-th selected expert
//   weights[n_expert_used * row + k] - its softmax weight, divided by the sum
//                                      of the selected weights if with_norm
//
// NOTE(review): the subgroup reductions assume that a subgroup spans exactly
// the WARP_SIZE invocations of a single row, and the lane masks assume that
// WARP_SIZE is a power of two - confirm against the host-side pipeline setup.
// NOTE(review): output_weights is indexed by k / WARP_SIZE, which presumes
// n_expert_used <= n_experts - confirm that the caller guarantees this.
void main() {
    const uint lane = gl_LocalInvocationID.x;
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint weights_offset = n_expert_used * row;
    // ids rows are strided by n_experts, not by n_expert_used.
    const uint ids_offset = n_experts * row;

    const float INFINITY = 1.0 / 0.0;

    // Load this invocation's share of the row. Slots that fall beyond
    // n_experts are filled with -inf so they contribute nothing to the
    // max / sum reductions below.
    float logits_r[experts_per_thread];

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
    }

    // Softmax, step 1: row maximum (local pass, then across the subgroup).
    float row_max = logits_r[0];

    [[unroll]]
    for (uint i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        row_max = max(val, row_max);
    }

    row_max = subgroupMax(row_max);

    // Softmax, step 2: exponentials shifted by the maximum, and their sum.
    float wt[experts_per_thread];
    float exp_sum = 0.f;

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = exp(val - row_max);
        exp_sum += wt[i];
    }

    exp_sum = subgroupAdd(exp_sum);

    // Softmax, step 3: normalise.
    const float inv_exp_sum = 1.0f / exp_sum;

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_exp_sum;
    }

    // At this point each invocation holds its portion of the softmax.
    // Top-k is done as n_expert_used rounds of argmax; after every round the
    // winning expert's weight is overwritten with -inf so that it cannot be
    // selected again.

    float wt_sum = 0.f;

    // Slot k / WARP_SIZE of invocation k % WARP_SIZE receives the weight of
    // the k-th pick. Slots whose k would be >= n_expert_used stay unwritten
    // and are never read.
    float output_weights[experts_per_thread];

    for (uint k = 0; k < n_expert_used; k++) {
        // Best candidate among the experts owned by this invocation.
        float best_val = wt[0];
        uint best_expert = lane;

        [[unroll]]
        for (uint i = 1; i < experts_per_thread; i++) {
            const uint expert = lane + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > best_val) {
                best_val = wt[i];
                best_expert = expert;
            }
        }

        // XOR-shuffle reduction across the subgroup. Equal values are
        // resolved in favour of the lower expert index, which is what lets
        // all invocations arrive at the same winner.
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(best_val, mask);
            const uint expert = subgroupShuffleXor(best_expert, mask);
            if (val > best_val || (val == best_val && expert < best_expert)) {
                best_val = val;
                best_expert = expert;
            }
        }

        // Invocation (k mod WARP_SIZE) records the k-th selected weight.
        if ((k & (WARP_SIZE - 1)) == lane) {
            output_weights[k / WARP_SIZE] = best_val;
        }

        // The invocation that owns the winning expert retires it, writes its
        // id, and accumulates its weight for the optional renormalisation.
        if ((best_expert & (WARP_SIZE - 1)) == lane) {
            wt[best_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = best_expert;
            if (with_norm) {
                wt_sum += best_val;
            }
        }
    }

    // Optional renormalisation factor: reciprocal of the sum of the selected
    // weights. The subgroupAdd is executed by every invocation of the row,
    // outside of any per-lane condition.
    float norm_scale = 1.0f;
    if (with_norm) {
        wt_sum = subgroupAdd(wt_sum);
        norm_scale = 1.0f / wt_sum;
    }

    // Write out the selected weights. Scaling happens inside the bounds
    // check so that only slots that were actually assigned are ever read.
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        const uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            float w = output_weights[i];
            if (with_norm) {
                w *= norm_scale;
            }
            weights[weights_offset + idx] = w;
        }
    }
}
|