llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"
layout (push_constant) uniform parameter
{
    uint n_rows;
    uint n_experts_push;
    uint n_expert_used;
    float clamp_min;
    float clamp_max;
};
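
// Each subgroup of WARP_SIZE lanes handles one row (WARP_SIZE is expected to match
// the device subgroup size); a workgroup covers gl_WorkGroupSize.y = 4 rows.
// n_experts comes from the push constant when nexperts_use_push is set, otherwise
// from the n_experts_spec specialization constant.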
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
layout(constant_id = 4) const bool nexperts_use_push = false;
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
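
// Number of expert slots each lane keeps in registers; experts are strided across the subgroup.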
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

const float INFINITY = 1.0 / 0.0;
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
    float max_val = -INFINITY;
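    // Pass 1: subgroup-wide maximum of the active values, for a numerically stable exp().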
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint idx = lane + i * WARP_SIZE;
        const bool is_active = !use_limit || (idx < limit);
        if (is_active) {
            max_val = max(max_val, vals[i]);
        }
    }
    max_val = subgroupMax(max_val);
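
    // Pass 2: exponentiate the active values and accumulate the subgroup-wide sum;
    // inactive slots are zeroed so they cannot contribute.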
    float sum = 0.f;
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint idx = lane + i * WARP_SIZE;
        const bool is_active = !use_limit || (idx < limit);
        if (is_active) {
            const float val = exp(vals[i] - max_val);
            vals[i] = val;
            sum += val;
        } else {
            vals[i] = 0.f;
        }
    }
    sum = subgroupAdd(sum);
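
    // Pass 3: normalize the active values by the subgroup-wide sum.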
    const float inv_sum = 1.0f / sum;
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint idx = lane + i * WARP_SIZE;
        const bool is_active = !use_limit || (idx < limit);
        if (is_active) {
            vals[i] *= inv_sum;
        }
    }
}

void main() {
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
    if (row >= n_rows) {
        return;
    }
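
    // Per-row offsets: the logits and ids tensors are strided by n_experts,
    // the output weights by n_expert_used.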
    const uint logits_offset = n_experts * row;
    const uint weights_offset = n_expert_used * row;
    const uint ids_offset = n_experts * row;

    const uint lane = gl_SubgroupInvocationID;
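
    // Load this row's logits into registers, one strided slot per lane;
    // out-of-range slots get -INFINITY so they can never win the argmax.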
    float wt[experts_per_thread];
    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
    }
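
    // Standard path: softmax over the full row of logits before the top-k selection.
    // With late_softmax the softmax is instead applied to the selected weights at the end.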
    if (!late_softmax) {
        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
    }

    // At this point each lane holds its strided portion of the (softmaxed) weights.
    // Run n_expert_used rounds of subgroup-wide argmax, each time setting the
    // winning expert's weight to -INFINITY so it is excluded from the next round.
    float wt_sum = 0.f;

    float output_weights[experts_per_thread];
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }
    for (int k = 0; k < n_expert_used; k++) {
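        // Lane-local argmax over this lane's strided slots.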
        float max_val = wt[0];
        uint max_expert = lane;
        [[unroll]]
        for (int i = 1; i < experts_per_thread; i++) {
            const uint expert = lane + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }
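
        // Subgroup-wide butterfly reduction of (max_val, max_expert);
        // ties are broken toward the lower expert index.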
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }
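
        // Lane (k mod WARP_SIZE) keeps the k-th selected weight; the lane owning the
        // winning expert retires it, records its id, and accumulates the weight sum
        // when with_norm is set.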
        if ((k & (WARP_SIZE - 1)) == lane) {
            output_weights[k / WARP_SIZE] = max_val;
        }
        if ((max_expert & (WARP_SIZE - 1)) == lane) {
            wt[max_expert / WARP_SIZE] = -INFINITY;
            ids[ids_offset + k] = max_expert;
            if (with_norm) {
                wt_sum += max_val;
            }
        }
    }
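
    // Renormalize the selected weights so they sum to 1; the sum is clamped
    // to [clamp_min, clamp_max] first.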
    if (with_norm) {
        wt_sum = subgroupAdd(wt_sum);
        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
        const float inv_sum = 1.0f / wt_sum;
        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }
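
    // Delayed softmax path: softmax only over the n_expert_used selected weights.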
    if (late_softmax) {
        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }
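
    // Each lane writes its strided share of the n_expert_used output weights for this row.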
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
        }
    }
}