diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 35015bc7f3..54dc43bc08 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3412,6 +3412,69 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph, + int node_idx, + int node_count, + int * out_nodes, + int out_count) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_nbytes(a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_nbytes(b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { + return true; + } + + return false; + }; + + bool is_ok = true; + // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok + if (ggml_nrows(cgraph->nodes[node_idx]) == 1) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } + } + + return is_ok; +} + static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3608,7 +3671,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud out_nodes[1] = i + ops.size() - 1; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) { + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; @@ -3623,7 +3687,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud int out_nodes[2] = { i + 1, i + 5 }; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) { + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 08a88990dd..3020e5c743 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -119,6 +119,18 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } + // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs. + // NaN comparisons always return false, which would cause the same expert to be + // selected repeatedly. -FLT_MAX compares normally and is still excluded by the + // -INFINITY sentinel used after each selection round. + // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659 +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + if (__isnanf(wt[i])) { + wt[i] = -FLT_MAX; + } + } + // selection_wt is only needed when bias is present (selection uses wt + bias) // when no bias, we use wt directly for both selection and weight values float selection_wt[has_bias ? experts_per_thread : 1];