// Top-k mixture-of-experts routing shader: (optional) softmax over expert
// logits, subgroup-level argmax selection of the top n_expert_used experts,
// then optional renormalization or late softmax of the selected weights.
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

// Per-dispatch parameters.
layout (push_constant) uniform parameter
{
    uint n_rows;          // number of rows (tokens) to route
    uint n_experts_push;  // runtime expert count; only read when nexperts_use_push is set
    uint n_expert_used;   // k: number of experts selected per row
    float clamp_min;      // lower bound applied to the top-k weight sum before normalization
    float clamp_max;      // upper bound applied to the top-k weight sum before normalization
};

// local_size_x is specialized to the subgroup size; local_size_y = 4 means
// each workgroup processes 4 rows (one row per subgroup).
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;            // subgroup size
layout(constant_id = 1) const uint n_experts_spec = 512;      // compile-time expert count (upper bound when push count is used)
layout(constant_id = 2) const bool with_norm = true;          // renormalize selected weights so they sum to 1
layout(constant_id = 3) const bool late_softmax = false;      // run softmax after top-k selection instead of before
layout(constant_id = 4) const bool nexperts_use_push = false; // take the expert count from the push constant

// Effective expert count: runtime (push constant) or compile-time (spec constant).
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Number of expert slots each subgroup thread owns; sized from the
// specialization constant so the arrays below have a compile-time size.
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

layout (binding = 0, std430) readonly buffer Logits {float logits[];};    // indexed [row * n_experts + expert]
layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; // indexed [row * n_expert_used + k]
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};          // row stride is n_experts (see ids_offset in main)

// GLSL has no infinity literal; constant-folds to +inf.
const float INFINITY = 1.0 / 0.0;
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
//
// vals      : this thread's strided slice of a row — thread `lane` owns
//             elements lane, lane + WARP_SIZE, lane + 2*WARP_SIZE, ...
// limit     : number of valid elements in the row; only honored when use_limit is true.
// lane      : gl_SubgroupInvocationID of the caller.
// use_limit : when false, ALL slots are treated as active — callers must then
//             pre-fill out-of-range slots with -INFINITY so exp() maps them to 0.
//
// Must be called with the whole subgroup active: subgroupMax/subgroupAdd below
// reduce across all lanes.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
    // Pass 1: per-thread max over active slots, then subgroup-wide max
    // (subtracting it keeps exp() numerically stable).
    float max_val = -INFINITY;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint idx = lane + i * WARP_SIZE;
        const bool is_active = !use_limit || (idx < limit);
        if (is_active) {
            max_val = max(max_val, vals[i]);
        }
    }

    max_val = subgroupMax(max_val);

    // Pass 2: exponentiate active slots in place, zero inactive ones, and
    // accumulate this thread's partial sum.
    float sum = 0.f;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint idx = lane + i * WARP_SIZE;
        const bool is_active = !use_limit || (idx < limit);
        if (is_active) {
            const float val = exp(vals[i] - max_val);
            vals[i] = val;
            sum += val;
        } else {
            vals[i] = 0.f;
        }
    }

    // Subgroup-wide sum of the partials.
    sum = subgroupAdd(sum);

    const float inv_sum = 1.0f / sum;

    // Pass 3: normalize active slots in place.
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint idx = lane + i * WARP_SIZE;
        const bool is_active = !use_limit || (idx < limit);
        if (is_active) {
            vals[i] *= inv_sum;
        }
    }
}
void main() {
    // One subgroup handles one row; gl_WorkGroupSize.y subgroups per workgroup.
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;       // row base in the logits buffer
    const uint weights_offset = n_expert_used * row;  // row base in the weights buffer
    const uint ids_offset = n_experts * row;          // row base in the ids buffer (stride n_experts)
    const uint lane = gl_SubgroupInvocationID;

    // This thread's strided slice of the row: slot i holds expert (lane + i*WARP_SIZE).
    float wt[experts_per_thread];

    // Load logits; out-of-range slots get -INFINITY so they never win the
    // argmax and contribute 0 to the softmax. The modulus test lets the
    // bounds check fold away when n_experts is a multiple of WARP_SIZE.
    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
    }

    if (!late_softmax) {
        // With a push-constant expert count the valid range is only known at
        // runtime, so enable the limit check in that case.
        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    float output_weights[experts_per_thread];

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        // Thread-local argmax over this thread's slice. Slot 0 (expert ==
        // lane) is always a safe seed: invalid slots were loaded as -INFINITY.
        float max_val = wt[0];
        uint max_expert = lane;

        [[unroll]]
        for (int i = 1; i < experts_per_thread; i++) {
            const uint expert = lane + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

        // XOR-shuffle butterfly: after log2(WARP_SIZE) steps every lane holds
        // the subgroup-wide argmax; ties are broken by the lower expert id so
        // all lanes agree deterministically.
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

        // The lane that owns output slot k records the k-th selected weight.
        if ((k & (WARP_SIZE - 1)) == lane) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        // The lane that owns the winning expert retires it (-INFINITY removes
        // it from later iterations), writes its id, and accumulates the
        // normalization sum (each expert is counted by exactly one lane).
        if ((max_expert & (WARP_SIZE - 1)) == lane) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            if (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if (with_norm) {
        // Renormalize the selected weights to sum to 1; the sum is clamped to
        // [clamp_min, clamp_max] first (also guards the division).
        wt_sum = subgroupAdd(wt_sum);
        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
        const float inv_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    if (late_softmax) {
        // Delayed path: softmax over just the n_expert_used selected weights.
        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }

    // Scatter the selected weights back out, same striding as the loads.
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
        }
    }
}