diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index e9df0ea4a7..76d0f12550 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3080,63 +3080,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
     return true;
 }
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
+    args.sigmoid         = false;
+    args.softmax         = false;
+    args.delayed_softmax = false;
+    args.prob_bias       = false;
+    args.norm            = false;
+    args.scale           = false;
+
+    const int      n_nodes = cgraph->n_nodes;
+    ggml_tensor ** nodes   = cgraph->nodes;
+
+    if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
+        args.softmax = true;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_UNARY) {
+        if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        args.sigmoid = true;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
+        args.delayed_softmax = true;
+    }
+
+    node_idx++;
+
+    if (args.sigmoid || args.softmax) {
+        // SOFTMAX -> RESHAPE
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx];
+        node_idx++;
+
+        if (node_idx >= n_nodes) {
+            return false;
+        }
+
+        // src of the bias add is the unreshaped probs (-2 instead of -1)
+        if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
+            args.prob_bias = true;
+            node_idx++;
+        }
+
+        // RESHAPE/ADD -> ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
+            return false;
+        }
+
+        if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
+            return false;
+        }
+
+        node_idx++;
+
+        // ARGSORT -> VIEW
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
+            return false;
+        }
+
+        // GET_ROWS gathers from the reshaped probs using the top-k ids
+        if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+    } else if (args.delayed_softmax) {
+        if (node_idx - 2 < 0) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx - 2];
+
+        // ARGSORT -> VIEW
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        // GET_ROWS
+        if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+            nodes[node_idx]->src[0] != probs_reshaped) {
+            return false;
+        }
+        node_idx++;
+
+        static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+        for (const ggml_op op : remaining_ops) {
+            if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+                return false;
+            }
+            node_idx++;
+        }
+    }
+
+    // Everything up to here forms a valid top-k MoE subgraph;
+    // what remains is the optional norm + scale tail.
+    if (node_idx >= n_nodes) {
+        return true;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
+        // check RESHAPE -> SUM_ROWS -> CLAMP -> DIV -> RESHAPE
+        static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
+
+        args.norm = true;
+        for (const ggml_op op : norm_ops) {
+            if (node_idx < n_nodes && nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+                node_idx++;
+            } else {
+                args.norm = false;
+                return true;
+            }
+        }
+
+        // DIV: src1 is the CLAMP output, src0 is the reshaped probs (3 nodes back)
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_DIV ||
+            nodes[node_idx]->src[1] != nodes[node_idx - 1] || nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
+            args.norm = false;
+            return true;
+        }
+        node_idx++;
+
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            args.norm = false;
+            return true;
+        }
+
+        node_idx++;
+    }
+
+    if (node_idx < n_nodes && nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+        args.scale = true;
+    }
+
+    return true;
+}
+
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
+                               int node_idx,
+                               std::initializer_list<enum ggml_op> ops,
+                               std::initializer_list<enum ggml_unary_op> unary_ops) {
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif
 
-    //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops =
-        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
-
     const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
                              const std::initializer_list<enum ggml_op> & list2) {
         return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
     };
 
-    if (is_equal(topk_moe_ops_with_norm, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
-        ggml_tensor * softmax  = cgraph->nodes[node_idx];
-        ggml_tensor * weights  = cgraph->nodes[node_idx + 9];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
-        int           n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
-        }
-    }
-
-    if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
-        ggml_tensor * softmax  = cgraph->nodes[node_idx];
-        ggml_tensor * weights  = cgraph->nodes[node_idx + 4];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
-        int           n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
-        }
-    }
-
-    if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
-        ggml_tensor * softmax  = cgraph->nodes[node_idx + 4];
-        ggml_tensor * weights  = cgraph->nodes[node_idx + 5];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
-        ggml_tensor * argsort  = cgraph->nodes[node_idx + 0];
-        int           n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
-        }
-    }
-
     std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT,
                                                                  GGML_OP_ADD, GGML_OP_GLU };
     std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID,
                                                                     GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
@@ -3398,35 +3501,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
         // start of fusion operations
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
         if (!disable_fusion) {
+            ggml_cuda_topk_moe_args args;
 
-            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                ggml_tensor * weights          = cgraph->nodes[i + 9];
-                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                ggml_tensor * clamp            = cgraph->nodes[i + 7];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                      /*delayed softmax*/ false, clamp);
-                i += 9;
-                continue;
-            }
+            if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
+                cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
+                const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
 
-            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                ggml_tensor * weights          = cgraph->nodes[i + 4];
-                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
-                                      /*delayed softmax*/ false);
-                i += 4;
-                continue;
-            }
+                std::vector<ggml_op> ops;
 
-            if (ggml_cuda_can_fuse(cgraph, i,
-                                   ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
-                ggml_tensor * weights = cgraph->nodes[i + 5];
-                ggml_tensor * ids     = cgraph->nodes[i + 1];
+                if (can_fuse) {
+                    const ggml_tensor * logits  = node->src[0];
+                    ggml_tensor *       weights = nullptr;
+                    ggml_tensor *       ids     = nullptr;
+                    const ggml_tensor * bias    = nullptr;
+                    const ggml_tensor * clamp   = nullptr;
+                    const ggml_tensor * scale   = nullptr;
 
-                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
-                                      /*delayed_softmax*/ true);
-                i += 5;
-                continue;
+                    if (!args.delayed_softmax) {
+                        ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
+                        int     out_nodes[2];  // nodes which can't be elided
+
+                        if (args.prob_bias) {
+                            bias = cgraph->nodes[i + 2]->src[1];
+                            ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
+                                                    GGML_OP_VIEW, GGML_OP_GET_ROWS });
+                            out_nodes[0] = i + 4;
+                            ids          = cgraph->nodes[i + 4];
+                        } else {
+                            ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                    GGML_OP_GET_ROWS });
+                            out_nodes[0] = i + 3;
+                            ids          = cgraph->nodes[i + 3];
+                        }
+
+                        if (args.norm) {
+                            ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                    GGML_OP_DIV, GGML_OP_RESHAPE });
+                            clamp = cgraph->nodes[i + ops.size() - 3];
+                        }
+                        if (args.scale) {
+                            ops.insert(ops.end(), { GGML_OP_SCALE });
+                            scale = cgraph->nodes[i + ops.size() - 1];
+                        }
+
+                        weights      = cgraph->nodes[i + ops.size() - 1];
+                        out_nodes[1] = i + ops.size() - 1;
+
+                        if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                            ggml_cuda_should_use_topk_moe(node, weights, logits, ids)) {
+                            ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                            i += ops.size() - 1;
+                            continue;
+                        }
+                    } else if (!args.norm && !args.prob_bias) {
+                        // special case for gpt-oss: delayed softmax, no norm, no bias
+                        ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
+                        weights = cgraph->nodes[i + 5];
+                        ids     = cgraph->nodes[i + 1];
+                        const ggml_tensor * softmax = cgraph->nodes[i + 4];
+
+                        int out_nodes[2] = { i + 1, i + 5 };
+                        if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                            ggml_cuda_should_use_topk_moe(softmax, weights, logits, ids)) {
+                            ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                            i += ops.size() - 1;
+                            continue;
+                        }
+                    }
+                }
             }
 
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 48e569efa0..08a88990dd 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -5,6 +5,13 @@
 #include <cmath>
 #include <initializer_list>
 
+// Kernel config struct - passed by value to the CUDA kernel
+struct topk_moe_config {
+    bool use_sigmoid;
+    bool with_norm;
+    bool delayed_softmax;
+};
+
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
 template <int experts_per_thread, bool use_limit>
 __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
     }
 }
 
+template <int experts_per_thread, bool use_limit>
+__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        vals[i]           = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
+    }
+}
+
 /*
     This kernel does the following:
     1. optionally softmax over the logits per token [n_experts, n_tokens]
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm, bool delayed_softmax>
-__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
-                                                                  float *       weights,
-                                                                  int32_t *     ids,
-                                                                  const int     n_rows,
-                                                                  const int     n_expert_used,
-                                                                  const float   clamp_val) {
+template <int n_experts, bool has_bias>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *         logits,
+                                                                  float *               weights,
+                                                                  int32_t *             ids,
+                                                                  float *               bias,
+                                                                  const int             n_rows,
+                                                                  const int             n_expert_used,
+                                                                  const float           clamp_val,
+                                                                  const float           scale_val,
+                                                                  const topk_moe_config config) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt[experts_per_thread];
 
+    // Initialize all slots to -INFINITY
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = -INFINITY;
+    }
+
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert  = i + threadIdx.x;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    if constexpr (!delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+    if (!config.delayed_softmax) {
+        if (config.use_sigmoid) {
+            sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+        } else {
+            softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+        }
+    }
+
+    // selection_wt is only needed when a bias is present (selection uses wt + bias);
+    // when there is no bias, wt is used directly for both selection and weight values
+    float selection_wt[has_bias ? experts_per_thread : 1];
+
+    if constexpr (has_bias) {
+#pragma unroll
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_wt[i] = -INFINITY;
+        }
+
+#pragma unroll
+        for (int i = 0; i < n_experts; i += WARP_SIZE) {
+            const int expert = i + threadIdx.x;
+            selection_wt[i / WARP_SIZE] =
+                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
+        }
     }
 
     //at this point, each thread holds either a portion of the softmax distribution
@@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
 
-#pragma unroll
-        for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = threadIdx.x + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
-                max_val    = wt[i];
-                max_expert = expert;
-            }
-        }
+        if constexpr (has_bias) {
+            float max_val_s = selection_wt[0];
 
 #pragma unroll
-        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
-            if (val > max_val || (val == max_val && expert < max_expert)) {
-                max_val    = val;
-                max_expert = expert;
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
+                    max_val    = wt[i];
+                    max_val_s  = selection_wt[i];
+                    max_expert = expert;
+                }
+            }
+
+#pragma unroll
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const float val_s  = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+                    max_val    = val;
+                    max_val_s  = val_s;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                selection_wt[max_expert / WARP_SIZE] = -INFINITY;
+            }
+        } else {
+#pragma unroll
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                    max_val    = wt[i];
+                    max_expert = expert;
+                }
+            }
+
+#pragma unroll
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val > max_val || (val == max_val && expert < max_expert)) {
+                    max_val    = val;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                wt[max_expert / WARP_SIZE] = -INFINITY;
             }
         }
 
@@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
 
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
-            wt[max_expert / WARP_SIZE] = -INFINITY;
-
             ids[k] = max_expert;
-            if constexpr (with_norm) {
+            if (config.with_norm) {
                 wt_sum += max_val;
             }
         }
     }
 
-    if constexpr (with_norm) {
+    if (config.with_norm) {
         wt_sum              = warp_reduce_sum(wt_sum);
         wt_sum              = max(wt_sum, clamp_val);
        const float inv_sum = 1.0f / wt_sum;
@@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
-    if constexpr (delayed_softmax) {
+    if (config.delayed_softmax) {
         softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
     }
 
@@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
         if (idx < n_expert_used) {
-            weights[idx] = output_weights[i];
+            weights[idx] = output_weights[i] * scale_val;
         }
     }
-
-    if (!with_norm) {
-        GGML_UNUSED(clamp_val);
-    }
 }
 
-template <bool with_norm, bool delayed_softmax = false>
+template <bool has_bias>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
                                  int32_t *                   ids,
+                                 float *                     bias,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used,
-                                 const float                 clamp_val) {
-    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+                                 const float                 clamp_val,
+                                 const float                 scale_val,
+                                 const topk_moe_config       config) {
+    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
+                "delayed softmax is not supported with weight normalization");
     const int rows_per_block = 4;
     dim3      grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3      block_dims(WARP_SIZE, rows_per_block, 1);
@@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                             clamp_val, scale_val, config);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                             clamp_val, scale_val, config);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                             clamp_val, scale_val, config);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                             clamp_val, scale_val, config);
            break;
         case 16:
-            topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                              clamp_val, scale_val, config);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                              clamp_val, scale_val, config);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                              clamp_val, scale_val, config);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                               clamp_val, scale_val, config);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                               clamp_val, scale_val, config);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                               clamp_val, scale_val, config);
+            break;
+        case 576:
+            topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                               clamp_val, scale_val, config);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     }
 }
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax,
-                           ggml_tensor *               clamp) {
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
+    float *       bias_d    = bias ? (float *) bias->data : nullptr;
+
+    float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
 
     GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
     const int n_expert_used = weights->ne[1];
 
+    const bool with_norm = clamp != nullptr;
+
     float clamp_val = -INFINITY;
-    if (with_norm) {
-        if (clamp) {
-            clamp_val = ggml_get_op_params_f32(clamp, 0);
-        }
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
+    if (clamp) {
+        clamp_val = ggml_get_op_params_f32(clamp, 0);
+    }
+
+    topk_moe_config config;
+    config.use_sigmoid     = args.sigmoid;
+    config.with_norm       = with_norm;
+    config.delayed_softmax = args.delayed_softmax;
+
+    if (bias) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used,
+                                   clamp_val, scale_val, config);
     } else {
-        GGML_ASSERT(clamp == nullptr);
-        if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                              clamp_val);
-        } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                               clamp_val);
-        }
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used,
+                                    clamp_val, scale_val, config);
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                    const ggml_tensor * weights,
-                                   const ggml_tensor * get_rows,
-                                   const ggml_tensor * argsort,
-                                   const ggml_tensor * clamp,
-                                   int n_expert) {
-    ggml_tensor * probs = get_rows->src[0];
-    if (probs->op != GGML_OP_RESHAPE) {
-        return false;
-    }
-    probs = probs->src[0];
-    ggml_tensor * selection_probs = argsort->src[0];
-
-    if (probs != selection_probs) {
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids) {
+    const int n_expert = ids->nb[1] / ids->nb[0];
+    if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
         return false;
     }
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
-
-    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
         return false;
     }
 
-    if (scale != 1.0f || max_bias != 0.0f) {
-        return false;
-    }
+    if (gating_op->op == GGML_OP_SOFT_MAX) {
+        const ggml_tensor * softmax = gating_op;
+
+        float scale    = 1.0f;
+        float max_bias = 0.0f;
 
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
-        return false;
-    }
+        memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
 
-    // n_expert must be a power of 2
-    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
-        return false;
-    }
-
-    if (clamp) {
-        if (clamp->op != GGML_OP_CLAMP) {
+        if (!ggml_is_contiguous(softmax->src[0])) {
             return false;
         }
-        float max_val = ggml_get_op_params_f32(clamp, 1);
 
-        if (max_val != INFINITY) {
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
+    } else if (gating_op->op == GGML_OP_UNARY) {
+        ggml_unary_op op = ggml_get_unary_op(gating_op);
+
+        if (op != GGML_UNARY_OP_SIGMOID) {
             return false;
         }
     }
 
     return true;
 }
-
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
-    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
-                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
-                                                            GGML_OP_RESHAPE };
-
-    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
-                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
-
-    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
-                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
-
-    GGML_ASSERT(!norm || !delayed_softmax);
-
-    if (delayed_softmax) {
-        return delayed_softmax_ops;
-    }
-
-    if (norm) {
-        return norm_ops;
-    }
-
-    return no_norm_ops;
-}
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 6b6c13c587..243dc2f1c4 100644
--- a/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -3,19 +3,25 @@
 #include <initializer_list>
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax = false,
-                           ggml_tensor *               weight_clamp    = nullptr);
+struct ggml_cuda_topk_moe_args {
+    bool sigmoid{};
+    bool softmax{};
+    bool delayed_softmax{};
+    bool prob_bias{};
+    bool norm{};
+    bool scale{};
+};
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                    const ggml_tensor * weights,
-                                   const ggml_tensor * get_rows,
-                                   const ggml_tensor * argsort,
-                                   const ggml_tensor * clamp,
-                                   int n_expert);
-
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids);
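
Reviewer note (not part of the patch): the kernel combines two warp-local building blocks — lane-wise sigmoid/softmax gating and a butterfly argmax over __shfl_xor_sync, repeated once per selected expert. The standalone sketch below (illustrative names only; it is not ggml code and makes no assumptions beyond a CUDA device with warp shuffles) shows a single top-1 step of that reduction, including the lower-expert-index tie-break that keeps all 32 lanes converging on the same winner:

// warp_top1_sketch.cu — build with: nvcc warp_top1_sketch.cu -o warp_top1_sketch
#include <cstdio>
#include <cmath>

#define WARP_SIZE 32

__global__ void warp_top1(const float * logits, int n_experts, float * out_gate, int * out_expert) {
    const int lane = threadIdx.x;

    // 1. lane-local gating: sigmoid over the logit this lane owns, -INFINITY padding past n_experts
    float gate   = lane < n_experts ? 1.f / (1.f + expf(-logits[lane])) : -INFINITY;
    int   expert = lane;

    // 2. butterfly reduction: after log2(WARP_SIZE) exchanges every lane holds the max gate
    //    and the smallest expert index attaining it (same tie-break as topk_moe_cuda)
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
        const float other_gate   = __shfl_xor_sync(0xFFFFFFFF, gate, mask, WARP_SIZE);
        const int   other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, WARP_SIZE);
        if (other_gate > gate || (other_gate == gate && other_expert < expert)) {
            gate   = other_gate;
            expert = other_expert;
        }
    }

    if (lane == 0) {
        *out_gate   = gate;
        *out_expert = expert;
    }
}

int main() {
    float h_logits[8] = { -1.f, 0.5f, 3.f, 0.f, 3.f, -2.f, 1.f, 0.25f };

    float * d_logits;
    float * d_gate;
    int *   d_expert;
    cudaMalloc(&d_logits, sizeof(h_logits));
    cudaMalloc(&d_gate, sizeof(float));
    cudaMalloc(&d_expert, sizeof(int));
    cudaMemcpy(d_logits, h_logits, sizeof(h_logits), cudaMemcpyHostToDevice);

    warp_top1<<<1, WARP_SIZE>>>(d_logits, 8, d_gate, d_expert);

    float gate;
    int   expert;
    cudaMemcpy(&gate, d_gate, sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(&expert, d_expert, sizeof(int), cudaMemcpyDeviceToHost);

    // experts 2 and 4 tie at sigmoid(3); the tie-break picks the lower index -> expert 2
    printf("top-1: expert %d, gate %.4f\n", expert, gate);

    cudaFree(d_logits);
    cudaFree(d_gate);
    cudaFree(d_expert);
    return 0;
}

Because the whole selection lives in registers and warp shuffles, the fused kernel needs no shared memory or inter-warp synchronization and can process one token per warp; topk_moe_cuda repeats this step n_expert_used times, masking each winner out with -INFINITY before the next pass.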