diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ab0f6fe9ce..55fa2e6a7c 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];

-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];

-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 572379fcbf..48e569efa0 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     }
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     float scale = 1.0f;
     float max_bias = 0.0f;

@@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }

-    const int n_expert = softmax->ne[0];
     // n_expert must be a power of 2
     if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
         return false;
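The guard added above can be read as a standalone predicate: the fused kernel gathers its routing weights from the very tensor it ranks, so fusion must be rejected whenever the graph biases the selection path before the argsort. A minimal sketch of the walk, with an illustrative helper name that is not part of the patch:

    // Fusion is only safe when get_rows reads a reshape of the same probs
    // tensor that argsort ranks; a biased selection path breaks this.
    static bool topk_moe_same_probs(const ggml_tensor * get_rows, const ggml_tensor * argsort) {
        const ggml_tensor * probs = get_rows->src[0];  // weights = get_rows(reshape(probs), ids)
        if (probs->op != GGML_OP_RESHAPE) {
            return false;
        }
        return probs->src[0] == argsort->src[0];       // ids = argsort(selection_probs)
    }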
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 2eff408b03..6b6c13c587 100644
--- a/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const bool delayed_softmax = false,
                            ggml_tensor * weight_clamp = nullptr);

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert);

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 34ec09d403..b40a57e48f 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12889,24 +12889,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc

     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;

     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }

+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     const float * op_params = (const float *)softmax->op_params;

     float scale = op_params[0];
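The Vulkan path repeats the same structural test as the CUDA one. The graph shape both backends are screening out mirrors the bias_probs branch of the test below; roughly, a sketch with illustrative tensor names using the public ggml API:

    ggml_tensor * probs           = ggml_sigmoid(ctx, logits);          // or ggml_soft_max(ctx, logits)
    ggml_tensor * selection_probs = ggml_add(ctx, probs, exp_probs_b);  // bias affects selection only
    ggml_tensor * ids             = ggml_argsort_top_k(ctx, selection_probs, n_expert_used);
    ggml_tensor * weights         = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), ids);

Here get_rows->src[0]->src[0] (probs) is not the same tensor as argsort->src[0] (selection_probs), so both backends now fall back to the unfused ops instead of computing weights from the biased values.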
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 416218b5b8..c395bbc88d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
     }
 };

+enum MoeGatingFunc {
+    GATING_FUNC_SOFTMAX,
+    GATING_FUNC_SIGMOID,
+    GATING_FUNC_SOFTMAX_WEIGHT,
+};
+
 struct test_topk_moe : public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    const bool delayed_softmax;
+    const bool bias_probs;
+    const MoeGatingFunc gating_func;
+    const float scale_w;

     test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
                   int n_expert_used = 1,
                   bool with_norm = false,
-                  bool delayed_softmax = false) :
+                  bool bias_probs = false,
+                  MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
+                  float scale_w = 0.0f) :
         ne(ne),
         n_expert_used(n_expert_used),
         with_norm(with_norm),
-        delayed_softmax(delayed_softmax) {
+        bias_probs(bias_probs),
+        gating_func(gating_func),
+        scale_w(scale_w) {
         GGML_ASSERT(n_expert_used <= ne[0]);
-        GGML_ASSERT(!(with_norm && delayed_softmax));
     }

-    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
+    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }

     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
         const int n_tokens = ne[1];
         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());

-        ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
-        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used);  // [n_expert_used, n_tokens]
+        ggml_tensor * probs =
+            (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
+            (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
+        ggml_set_name(probs, "probs");

-        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);  // [1, n_expert_used, n_tokens]
+        ggml_tensor * selection_probs = probs;
+        if (bias_probs) {
+            ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+            ggml_set_name(exp_probs_b, "exp_probs_b");
+            selection_probs = ggml_add(ctx, probs, exp_probs_b);
+            ggml_set_name(selection_probs, "selection_probs");
+        }

-        if (delayed_softmax) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            out = ggml_soft_max(ctx, out);  // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used);  // [n_expert_used, n_tokens]
+        ggml_set_name(selected_experts, "selected_experts");
+
+        ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);  // [1, n_expert_used, n_tokens]
+        ggml_set_name(weights, "weights");
+
+        if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            weights = ggml_soft_max(ctx, weights);  // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }

         if (with_norm) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out);  // [1, n_tokens]
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);  // [1, n_tokens]
+            ggml_set_name(weights_sum, "weights_sum");

             weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
-            out = ggml_div(ctx, out, weights_sum);  // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+            weights = ggml_div(ctx, weights, weights_sum);  // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }

-        ggml_set_name(out, "out");
-        return out;
+        if (scale_w) {
+            weights = ggml_scale(ctx, weights, scale_w);
+        }
+
+        ggml_set_name(weights, "weights");
+        return weights;
     }
 };

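For reference, the widened constructor keeps the old positional order and defaults, so a hypothetical extra case exercising all of the new knobs at once would read:

    // illustrative only: ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w
    test_cases.emplace_back(new test_topk_moe({64, 22, 1, 1}, 8,
            /*with_norm*/ true, /*bias_probs*/ true, GATING_FUNC_SIGMOID, /*scale_w*/ 2.0f));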
@@ -7972,19 +8002,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }

-    for (bool with_norm : {false, true}) {
-        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
-        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
-        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
+    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
+        for (bool with_norm : {false, true}) {
+            for (bool bias_probs : {false, true}) {
+                for (float scale_w : {0.0f, 2.0f}) {
+                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                }
+            }
+        }
     }

-    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
-    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
-
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));
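As a quick sanity check on the loop nest above (a back-of-the-envelope count, not part of the patch):

    // 3 gating funcs x 2 with_norm x 2 bias_probs x 2 scale_w values x 7 shapes
    constexpr int n_topk_moe_cases = 3 * 2 * 2 * 2 * 7;  // 168 eval cases, up from the previous 16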