vulkan/cuda: fix topk_moe with exp_probs_b
I updated test_topk_moe to more closely match llm_graph_context::build_moe_ffn and added coverage for exp_probs_b and some other missing combinations. This exposed a bug in both CUDA and Vulkan backends where they were assuming the input to argsort and the input to get_rows are the same. I'd like to optimize this graph in another change, but for now just get it functional. CUDA also had a bug where it got n_experts from the wrong place, leading to GGML_ASSERT failures in some of the new tests.
This commit is contained in:
parent
5c8a717128
commit
e262bf9798
|
|
@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * get_rows,
|
||||
const ggml_tensor * argsort,
|
||||
const ggml_tensor * clamp,
|
||||
int n_expert) {
|
||||
ggml_tensor * probs = get_rows->src[0];
|
||||
if (probs->op != GGML_OP_RESHAPE) {
|
||||
return false;
|
||||
}
|
||||
probs = probs->src[0];
|
||||
ggml_tensor * selection_probs = argsort->src[0];
|
||||
|
||||
if (probs != selection_probs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
|
|
@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
|
|||
return false;
|
||||
}
|
||||
|
||||
const int n_expert = softmax->ne[0];
|
||||
// n_expert must be a power of 2
|
||||
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
const bool delayed_softmax = false,
|
||||
ggml_tensor * weight_clamp = nullptr);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * get_rows,
|
||||
const ggml_tensor * argsort,
|
||||
const ggml_tensor * clamp,
|
||||
int n_expert);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||
|
|
|
|||
|
|
@ -12889,24 +12889,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
|
|||
|
||||
const ggml_tensor * softmax;
|
||||
const ggml_tensor * weights;
|
||||
const ggml_tensor * get_rows;
|
||||
const ggml_tensor * argsort;
|
||||
|
||||
switch (mode) {
|
||||
case TOPK_MOE_EARLY_SOFTMAX_NORM:
|
||||
softmax = cgraph->nodes[node_idx + 0];
|
||||
weights = cgraph->nodes[node_idx + 9];
|
||||
get_rows = cgraph->nodes[node_idx + 4];
|
||||
argsort = cgraph->nodes[node_idx + 2];
|
||||
break;
|
||||
case TOPK_MOE_EARLY_SOFTMAX:
|
||||
softmax = cgraph->nodes[node_idx + 0];
|
||||
weights = cgraph->nodes[node_idx + 4];
|
||||
get_rows = cgraph->nodes[node_idx + 4];
|
||||
argsort = cgraph->nodes[node_idx + 2];
|
||||
break;
|
||||
case TOPK_MOE_LATE_SOFTMAX:
|
||||
softmax = cgraph->nodes[node_idx + 4];
|
||||
weights = cgraph->nodes[node_idx + 5];
|
||||
get_rows = cgraph->nodes[node_idx + 2];
|
||||
argsort = cgraph->nodes[node_idx + 0];
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_tensor * probs = get_rows->src[0];
|
||||
if (probs->op != GGML_OP_RESHAPE) {
|
||||
return false;
|
||||
}
|
||||
probs = probs->src[0];
|
||||
ggml_tensor * selection_probs = argsort->src[0];
|
||||
|
||||
if (probs != selection_probs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const float * op_params = (const float *)softmax->op_params;
|
||||
|
||||
float scale = op_params[0];
|
||||
|
|
|
|||
|
|
@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
enum MoeGatingFunc {
|
||||
GATING_FUNC_SOFTMAX,
|
||||
GATING_FUNC_SIGMOID,
|
||||
GATING_FUNC_SOFTMAX_WEIGHT,
|
||||
};
|
||||
|
||||
struct test_topk_moe : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int n_expert_used;
|
||||
const bool with_norm;
|
||||
const bool delayed_softmax;
|
||||
const bool bias_probs;
|
||||
const MoeGatingFunc gating_func;
|
||||
const float scale_w;
|
||||
|
||||
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
|
||||
int n_expert_used = 1,
|
||||
bool with_norm = false,
|
||||
bool delayed_softmax = false) :
|
||||
bool bias_probs = false,
|
||||
MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
|
||||
float scale_w = 0.0f) :
|
||||
ne(ne),
|
||||
n_expert_used(n_expert_used),
|
||||
with_norm(with_norm),
|
||||
delayed_softmax(delayed_softmax) {
|
||||
bias_probs(bias_probs),
|
||||
gating_func(gating_func),
|
||||
scale_w(scale_w) {
|
||||
GGML_ASSERT(n_expert_used <= ne[0]);
|
||||
GGML_ASSERT(!(with_norm && delayed_softmax));
|
||||
}
|
||||
|
||||
std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
|
||||
std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
|
|
@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
|
|||
const int n_tokens = ne[1];
|
||||
|
||||
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
|
||||
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
ggml_tensor * probs =
|
||||
(gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
|
||||
(gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
|
||||
ggml_set_name(probs, "probs");
|
||||
|
||||
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
ggml_tensor * selection_probs = probs;
|
||||
if (bias_probs) {
|
||||
ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_set_name(exp_probs_b, "exp_probs_b");
|
||||
selection_probs = ggml_add(ctx, probs, exp_probs_b);
|
||||
ggml_set_name(selection_probs, "selection_probs");
|
||||
}
|
||||
|
||||
if (delayed_softmax) {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
ggml_set_name(selected_experts, "selected_experts");
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
ggml_set_name(weights, "weights");
|
||||
|
||||
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
|
||||
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
|
||||
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
|
||||
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
|
||||
if (with_norm) {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
|
||||
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
|
||||
ggml_set_name(weights_sum, "weights_sum");
|
||||
|
||||
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
|
||||
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
|
||||
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
if (scale_w) {
|
||||
weights = ggml_scale(ctx, weights, scale_w);
|
||||
}
|
||||
|
||||
ggml_set_name(weights, "weights");
|
||||
return weights;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -7972,19 +8002,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
|
||||
for (bool with_norm : {false, true}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
|
||||
for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
|
||||
for (bool with_norm : {false, true}) {
|
||||
for (bool bias_probs : {false, true}) {
|
||||
for (float scale_w : {0.0f, 2.0f}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
|
|
|
|||
Loading…
Reference in New Issue