CUDA: refactor topk-moe to enable more models (GLM 4.7, Nemotron etc.) (#19126)

This commit is contained in:
Aman Gupta 2026-01-29 10:31:28 +08:00 committed by GitHub
parent d4964a7c66
commit 3bcc990997
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 426 additions and 225 deletions

View File

@ -3080,63 +3080,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
args.delayed_softmax = false;
args.prob_bias = false;
args.norm = false;
const int n_nodes = cgraph->n_nodes;
ggml_tensor ** nodes = cgraph->nodes;
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
args.softmax = true;
}
if (nodes[node_idx]->op == GGML_OP_UNARY) {
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
return false;
}
args.sigmoid = true;
}
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
args.delayed_softmax = true;
}
node_idx++;
if (args.sigmoid || args.softmax) {
// SOFTMAX -> RESHAPE
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx];
node_idx++;
if (node_idx >= n_nodes) {
return false;
}
// src of bias add is the unreshaped probs (-2 instead of -1)
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
args.prob_bias = true;
node_idx++;
}
// RESHAPE/ADD -> ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
return false;
}
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
return false;
}
node_idx++;
// ARGSORT-> VIEW
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
return false;
}
// GET_ROWS
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
} else if (args.delayed_softmax) {
if (node_idx - 2 < 0) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
// VIEW->ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
// GET_ROWS
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != probs_reshaped) {
return false;
}
node_idx++;
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
for (const ggml_op op : remaining_ops) {
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
}
}
// At this point we can check for norm + scale. Everything is now at least valid till the norm
if (node_idx >= n_nodes) {
return true;
}
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
//check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
args.norm = true;
for (const ggml_op op : norm_ops) {
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
node_idx++;
} else {
args.norm = false;
return true;
}
}
// DIV <- CLAMP, RESHAPE
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
args.norm = false;
return true;
}
node_idx++;
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
args.norm = false;
return true;
}
node_idx++;
}
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
args.scale = true;
}
return true;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
GGML_ASSERT(unary_ops.size() == num_unary);
#endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops =
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
@ -3398,35 +3501,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
ggml_cuda_topk_moe_args args;
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 9];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
continue;
}
std::vector<ggml_op> ops;
if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];
if (can_fuse) {
const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = nullptr;
ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
if (!args.delayed_softmax) {
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
int out_nodes[2]; // nodes which can't be elided
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
} else if (!args.norm && !args.prob_bias) {
//special case gpt-oss, no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
}
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {

View File

@ -5,6 +5,13 @@
#include <cmath>
#include <initializer_list>
// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
bool use_sigmoid;
bool with_norm;
bool delayed_softmax;
};
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
}
}
template <int experts_per_thread, bool use_limit>
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
}
}
/*
This kernel does the following:
1. optionally softmax over the logits per token [n_experts, n_tokens]
@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used,
const float clamp_val) {
template <int n_experts, bool has_bias>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert_used,
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt[experts_per_thread];
// Initialize all slots to -INFINITY
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}
if constexpr (!delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
if (!config.delayed_softmax) {
if (config.use_sigmoid) {
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
} else {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
}
// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
if constexpr (has_bias) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
selection_wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
selection_wt[i / WARP_SIZE] =
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
}
}
//at this point, each thread holds either a portion of the softmax distribution
@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float max_val = wt[0];
int max_expert = threadIdx.x;
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
if constexpr (has_bias) {
float max_val_s = selection_wt[0];
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
max_val = wt[i];
max_val_s = selection_wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val;
max_val_s = val_s;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
}
} else {
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
}
}
@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[k] = max_expert;
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum += max_val;
}
}
}
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum;
@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
if constexpr (delayed_softmax) {
if (config.delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}
@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
weights[idx] = output_weights[i] * scale_val;
}
}
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
}
template <bool with_norm, bool delayed_softmax = false>
template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert,
const int n_expert_used,
const float clamp_val) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
"delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 2:
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 4:
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 8:
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 16:
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 32:
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 64:
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 128:
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 256:
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 512:
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 576:
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
default:
GGML_ASSERT(false && "fatal error");
@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
}
}
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax,
ggml_tensor * clamp) {
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
float * bias_d = bias ? (float *) bias->data : nullptr;
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
const int n_expert_used = weights->ne[1];
const bool with_norm = clamp != nullptr;
float clamp_val = -INFINITY;
if (with_norm) {
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
topk_moe_config config;
config.use_sigmoid = args.sigmoid;
config.with_norm = with_norm;
config.delayed_softmax = args.delayed_softmax;
if (bias) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
} else {
GGML_ASSERT(clamp == nullptr);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
const ggml_tensor * logits,
const ggml_tensor * ids) {
const int n_expert = ids->nb[1] / ids->nb[0];
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
if (gating_op->op == GGML_OP_SOFT_MAX) {
const ggml_tensor * softmax = gating_op;
float scale = 1.0f;
float max_bias = 0.0f;
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
if (!ggml_is_contiguous(softmax->src[0])) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
} else if (gating_op->op == GGML_OP_UNARY) {
ggml_unary_op op = ggml_get_unary_op(gating_op);
if (op != GGML_UNARY_OP_SIGMOID) {
return false;
}
}
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) {
return norm_ops;
}
return no_norm_ops;
}

View File

@ -3,19 +3,25 @@
#include <initializer_list>
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
struct ggml_cuda_topk_moe_args {
bool sigmoid{};
bool softmax{};
bool delayed_softmax{};
bool prob_bias{};
bool norm{};
bool scale{};
};
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
const ggml_tensor * logits,
const ggml_tensor * ids);