commit 227bd469a0
Author: Jeff Bolz
Date:   2025-12-17 06:42:43 +01:00 (committed via GitHub)
5 changed files with 117 additions and 35 deletions

File 1 of 5: CUDA backend fusion check (ggml_cuda_can_fuse)

@@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
+        int           n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
+        int           n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+        ggml_tensor * argsort  = cgraph->nodes[node_idx + 0];
+        int           n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
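
Note: all three variants hard-code where the argsort and get_rows nodes sit relative to node_idx, and every variant now reads n_expert from the first fused node's input (cgraph->nodes[node_idx]->src[0]->ne[0]). Collected in one place, the offsets used above look like the sketch below; the struct and names are illustrative only and are not part of the diff, but the numbers are taken directly from the hunks:

    // Illustrative summary of the node offsets each fusion variant inspects.
    struct topk_moe_offsets {
        int softmax;   // routing probabilities
        int argsort;   // expert selection order
        int get_rows;  // gather of the selected probabilities
        int weights;   // subgraph output passed to the check
    };
    constexpr topk_moe_offsets early_softmax_norm = { 0, 2, 4, 9 };
    constexpr topk_moe_offsets early_softmax      = { 0, 2, 4, 4 };
    constexpr topk_moe_offsets late_softmax       = { 4, 0, 2, 5 };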

File 2 of 5: CUDA top-k MoE implementation (ggml_cuda_should_use_topk_moe)

@@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int                 n_expert) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+
+    ggml_tensor * selection_probs = argsort->src[0];
+    if (probs != selection_probs) {
+        return false;
+    }
+
     float scale = 1.0f;
     float max_bias = 0.0f;

@@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
         return false;
     }
 
-    const int n_expert = softmax->ne[0];
     // n_expert must be a power of 2
     if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
         return false;
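
Note: this walk is the substantive change. The check no longer trusts that the subgraph's op sequence alone guarantees the fused kernel's assumptions; it now verifies that get_rows gathers from a reshape of exactly the tensor that argsort sorts. Without this, a graph that biases the selection probabilities (adding a per-expert bias before the top-k while still gathering the weights from the unbiased probabilities) could match the pattern and fuse incorrectly. A minimal sketch of such a graph, using the same ggml calls as the test file below; all names are illustrative and context setup is omitted:

    // Sketch (not from the diff): biased-selection routing that the new
    // check rejects, so the unfused path is taken instead.
    static ggml_tensor * build_biased_routing(ggml_context * ctx,
                                              ggml_tensor  * logits,       // [n_expert, n_tokens]
                                              ggml_tensor  * exp_probs_b,  // per-expert selection bias
                                              int64_t n_expert, int64_t n_tokens, int n_expert_used) {
        ggml_tensor * probs           = ggml_soft_max(ctx, logits);
        ggml_tensor * selection_probs = ggml_add(ctx, probs, exp_probs_b);  // bias affects selection only
        ggml_tensor * selected        = ggml_argsort_top_k(ctx, selection_probs, n_expert_used);
        // ...but the routing weights are gathered from the unbiased probs:
        ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected);
        // argsort->src[0] is selection_probs, while get_rows gathers from
        // reshape(probs); probs != selection_probs, so the check returns false.
        return weights;
    }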

File 3 of 5: CUDA top-k MoE header (declarations)

@@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const bool delayed_softmax = false,
                            ggml_tensor * weight_clamp = nullptr);
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert);
 
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
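
Note: n_expert is now an explicit parameter instead of being derived inside the implementation (the removed const int n_expert = softmax->ne[0]; above). That derivation is only valid when the softmax runs on the full router output; in the late-softmax layout the softmax is applied to the already-gathered rows, so its first dimension is n_expert_used rather than n_expert (an inference from the node offsets, not stated in the diff). Callers therefore supply it from the first fused node's input, as in this call-site sketch mirroring the hunks in the first file:

    int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
    if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort,
                                      /*clamp=*/nullptr, n_expert)) {
        // the fused top-k MoE kernel is safe to use
    }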

File 4 of 5: Vulkan backend fusion check (ggml_vk_can_fuse_topk_moe)

@@ -12889,24 +12889,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx,
     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;
 
     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort  = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort  = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort  = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }
 
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+
+    ggml_tensor * selection_probs = argsort->src[0];
+    if (probs != selection_probs) {
+        return false;
+    }
+
     const float * op_params = (const float *)softmax->op_params;
 
     float scale = op_params[0];
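
Note: the Vulkan backend duplicates the CUDA walk verbatim, keyed off its existing TOPK_MOE_* mode enum rather than separate op lists. For contrast with the rejected pattern shown earlier, the shape that both backends still fuse keeps selection and gathering on the same tensor; a sketch with illustrative names:

    // Sketch (not from the diff): unbiased routing, which passes the check.
    static ggml_tensor * build_unbiased_routing(ggml_context * ctx,
                                                ggml_tensor  * logits,  // [n_expert, n_tokens]
                                                int64_t n_expert, int64_t n_tokens, int n_expert_used) {
        ggml_tensor * probs    = ggml_soft_max(ctx, logits);
        ggml_tensor * selected = ggml_argsort_top_k(ctx, probs, n_expert_used);
        ggml_tensor * weights  = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected);
        // Both argsort and get_rows trace back to the same probs tensor,
        // so probs == selection_probs and the fusion proceeds.
        return weights;
    }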

File 5 of 5: backend op tests (test_topk_moe)

@@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
     }
 };
 
+enum MoeGatingFunc {
+    GATING_FUNC_SOFTMAX,
+    GATING_FUNC_SIGMOID,
+    GATING_FUNC_SOFTMAX_WEIGHT,
+};
+
 struct test_topk_moe : public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    const bool delayed_softmax;
+    const bool bias_probs;
+    const MoeGatingFunc gating_func;
+    const float scale_w;
 
     test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
                   int n_expert_used = 1,
                   bool with_norm = false,
-                  bool delayed_softmax = false) :
+                  bool bias_probs = false,
+                  MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
+                  float scale_w = 0.0f) :
         ne(ne),
         n_expert_used(n_expert_used),
         with_norm(with_norm),
-        delayed_softmax(delayed_softmax) {
+        bias_probs(bias_probs),
+        gating_func(gating_func),
+        scale_w(scale_w) {
         GGML_ASSERT(n_expert_used <= ne[0]);
-        GGML_ASSERT(!(with_norm && delayed_softmax));
     }
 
-    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
+    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
         const int n_tokens = ne[1];
         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-        ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
-        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_tensor * probs =
+            (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
+            (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
+        ggml_set_name(probs, "probs");
 
-        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        ggml_tensor * selection_probs = probs;
+        if (bias_probs) {
+            ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+            ggml_set_name(exp_probs_b, "exp_probs_b");
+            selection_probs = ggml_add(ctx, probs, exp_probs_b);
+            ggml_set_name(selection_probs, "selection_probs");
+        }
 
-        if (delayed_softmax) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_set_name(selected_experts, "selected_experts");
+
+        ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        ggml_set_name(weights, "weights");
+
+        if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }
 
         if (with_norm) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
+            ggml_set_name(weights_sum, "weights_sum");
 
             weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
-            out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+            weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }
 
-        ggml_set_name(out, "out");
-        return out;
+        if (scale_w) {
+            weights = ggml_scale(ctx, weights, scale_w);
+        }
+
+        ggml_set_name(weights, "weights");
+        return weights;
     }
 };
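
Note: the test now sweeps the gating function (softmax, sigmoid, or softmax applied to the gathered weights, which subsumes the old delayed_softmax flag), an optional selection bias, and an optional weight scale. One concrete case from the enumeration added below, for reference:

    // Sigmoid gating with biased selection plus normalized, scaled weights;
    // the argument values come from the loops below.
    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, /*n_expert_used=*/8,
                                              /*with_norm=*/true, /*bias_probs=*/true,
                                              GATING_FUNC_SIGMOID, /*scale_w=*/2.0f));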
@@ -7972,19 +8002,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    for (bool with_norm : {false, true}) {
-        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
-        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
-        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
+    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
+        for (bool with_norm : {false, true}) {
+            for (bool bias_probs : {false, true}) {
+                for (float scale_w : {0.0f, 2.0f}) {
+                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                }
+            }
+        }
     }
 
-    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
-    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
-
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));