vulkan: Allow non-pow2 n_experts in topk_moe (#17872)

This commit is contained in:
Jeff Bolz 2025-12-13 01:40:04 -06:00 committed by GitHub
parent 2bc94e7928
commit 07a10c1090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 11 deletions

View File

@ -757,7 +757,8 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
std::vector<vk_pipeline_ref> all_pipelines;
@ -1149,6 +1150,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_experts_push;
uint32_t n_expert_used;
float clamp_min;
float clamp_max;
@ -4204,10 +4206,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
for (uint32_t use_push = 0; use_push < 2; ++use_push) {
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
}
}
for (auto &c : compiles) {
@ -8554,7 +8558,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
// use n_experts from push constant if it's not equal to the power of two spec constant
bool use_push = dst->ne[0] != (1u << idx);
return ctx->device->pipeline_topk_moe[idx][mode][use_push];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@ -10158,6 +10164,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_experts_push = n_experts;
pc.n_expert_used = n_expert_used;
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
@ -12832,8 +12839,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
return false;
}

View File

@ -10,6 +10,7 @@
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_experts_push;
uint n_expert_used;
float clamp_min;
float clamp_max;
@ -18,11 +19,16 @@ layout (push_constant) uniform parameter
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
layout(constant_id = 4) const bool nexperts_use_push = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
@ -94,7 +100,7 @@ void main() {
}
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, lane, false);
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
}
// at this point, each thread holds a portion of softmax,

View File

@ -7971,8 +7971,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (bool with_norm : {false, true}) {
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
}
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));