llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

140 lines
4.1 KiB
Plaintext

#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

// Per-dispatch parameters.
layout (push_constant) uniform parameter
{
    uint n_rows;         // rows at or beyond this index are skipped by main()
    uint n_expert_used;  // number of experts selected (and weights written) per row
};

// The x extent of the workgroup comes from specialization constant 0, the same
// id that feeds WARP_SIZE below. Each of the 4 y-slices handles a separate row.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;   // logits per row
layout(constant_id = 2) const bool with_norm = true;  // renormalize the selected weights to sum to 1

// Number of experts each invocation keeps in its private arrays. Expert
// (lane + i * WARP_SIZE) lives in slot i, so the slot count must be rounded UP:
// with a truncating division a non-multiple n_experts > WARP_SIZE would leave
// the last partial stripe without a slot and the load loop in main() would
// index past the end of the array. For n_experts <= WARP_SIZE or an exact
// multiple of WARP_SIZE the value is the same as with plain division.
const uint experts_per_thread = (n_experts > WARP_SIZE) ? (n_experts + WARP_SIZE - 1u) / WARP_SIZE : 1;

// Input logits, read with a row stride of n_experts.
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
// Selected weights, written with a row stride of n_expert_used.
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
// Selected expert indices, written with a row stride of n_experts.
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
// Fused softmax + top-k expert selection.
//
// Each y-slice of the workgroup processes one row of n_experts logits:
//   1. load the row, striped across the WARP_SIZE invocations along x,
//   2. compute a numerically stable softmax using subgroup reductions,
//   3. run n_expert_used rounds of argmax, writing the winning expert index to
//      ids[] and remembering its weight,
//   4. optionally (with_norm) rescale the selected weights so they sum to 1,
//   5. write the selected weights to weights[].
//
// NOTE(review): the subgroup reductions only yield per-row results if the
// WARP_SIZE invocations along x form exactly one subgroup — confirm against the
// host-side dispatch.
void main() {
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
    // row depends only on the y index, so all invocations of a row exit together.
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint weights_offset = n_expert_used * row;
    // ids rows are n_experts apart although only n_expert_used entries per row are written here.
    const uint ids_offset = n_experts * row;

    float logits_r[experts_per_thread];

    const float INFINITY = 1.0 / 0.0;

    // Slot (i / WARP_SIZE) of this invocation holds expert (i + lane). Experts
    // past the end of the row are padded with -inf so they can never be the
    // maximum and contribute exp(-inf) = 0 to the softmax sum.
    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + gl_LocalInvocationID.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
    }

    // Row maximum: local maximum first, then across the subgroup.
    float max_val = logits_r[0];

    [[unroll]]
    for (uint i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = subgroupMax(max_val);

    // Softmax numerators (shifted by the maximum) and their sum.
    float wt[experts_per_thread];
    float tmp = 0.f;

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = exp(val - max_val);
        tmp += wt[i];
    }

    tmp = subgroupAdd(tmp);

    const float inv_sum = 1.0f / tmp;

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    // Slot (k / WARP_SIZE) of lane (k % WARP_SIZE) receives the k-th selected
    // weight. Zero-fill first so the normalization pass below never reads an
    // undefined value from a slot that no selection round wrote.
    float output_weights[experts_per_thread];

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        output_weights[i] = 0.f;
    }

    for (uint k = 0; k < n_expert_used; k++) {
        // Best remaining candidate among this invocation's own experts.
        float best_val = wt[0];
        uint best_expert = gl_LocalInvocationID.x;

        [[unroll]]
        for (uint i = 1; i < experts_per_thread; i++) {
            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > best_val) {
                best_val = wt[i];
                best_expert = expert;
            }
        }

        // Butterfly reduction across the subgroup. Ties are broken towards the
        // smaller expert index so every lane converges on the same winner.
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(best_val, mask);
            const uint expert = subgroupShuffleXor(best_expert, mask);
            if (val > best_val || (val == best_val && expert < best_expert)) {
                best_val = val;
                best_expert = expert;
            }
        }

        // The masks below equal "% WARP_SIZE" only for a power-of-two WARP_SIZE.
        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            output_weights[k / WARP_SIZE] = best_val;
        }

        // Only the lane that owns the winning expert retires it, records its
        // index, and accumulates its weight for the optional normalization.
        if ((best_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            wt[best_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = best_expert;
            if (with_norm) {
                wt_sum += best_val;
            }
        }
    }

    if (with_norm) {
        // Each selected weight was added by exactly one lane; combine them.
        wt_sum = subgroupAdd(wt_sum);
        const float inv_wt_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_wt_sum;
        }
    }

    // Write back only the slots that correspond to an actual selection.
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        const uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;

        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
        }
    }
}