diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 4bf6d2bcb0..ef2f202ec9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -101,6 +101,10 @@ void main() { const uint lane = gl_SubgroupInvocationID; float probs[experts_per_thread]; + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + probs[i] = -INFINITY; + } [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { @@ -112,8 +116,9 @@ void main() { softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push); } else if (gating_func == GATING_FUNC_SIGMOID) { [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - probs[i] = 1.f / (1.f + exp(-probs[i])); + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY; } } @@ -150,11 +155,11 @@ void main() { uint max_expert = lane; [[unroll]] - for (int i = 1; i < experts_per_thread; i++) { - const uint expert = lane + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) { - max_val = probs[i]; - max_val_s = selection_probs[i]; + for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) { + max_val = probs[i / WARP_SIZE]; + max_val_s = selection_probs[i / WARP_SIZE]; max_expert = expert; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8df994e91c..15567abedc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8184,6 +8184,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w)); test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w)); } } }