vulkan: fix topk_moe_sigmoid_norm_bias failures in GLM-4.6 (#18582)

This commit is contained in:
Jeff Bolz 2026-01-05 04:51:39 -06:00 committed by GitHub
parent 2da64a2f8a
commit f1768d8f03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 7 deletions

View File

@ -101,6 +101,10 @@ void main() {
const uint lane = gl_SubgroupInvocationID;
float probs[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
probs[i] = -INFINITY;
}
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
@ -112,8 +116,9 @@ void main() {
softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
} else if (gating_func == GATING_FUNC_SIGMOID) {
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
probs[i] = 1.f / (1.f + exp(-probs[i]));
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
}
}
@ -150,11 +155,11 @@ void main() {
uint max_expert = lane;
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = lane + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
max_val = probs[i];
max_val_s = selection_probs[i];
for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
max_val = probs[i / WARP_SIZE];
max_val_s = selection_probs[i / WARP_SIZE];
max_expert = expert;
}
}

View File

@ -8184,6 +8184,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w));
}
}
}