diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8a9cfaf165..a1149e606e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13820,12 +13820,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } -// Check whether the tensors overlap in memory but are not equal. -// Fusions can potenitally overwrite src tensors in ways that are not prevented -// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them -// to overlap if they are exactly equal. -// XXX TODO this check is probably missing from several fusion optimizations. -static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { +// Check whether the tensors overlap in memory. +// Fusions can potentially overwrite src tensors in ways that are not prevented +// by ggml-alloc. If the fusion src is being applied in a way that's elementwise +// with the destination, then it's OK for them to overlap if they are exactly equal. +static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) { ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; vk_buffer a_buf = a_buf_ctx->dev_buffer; ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; @@ -13836,7 +13835,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g auto b_base = vk_tensor_offset(b) + b->view_offs; auto b_size = ggml_nbytes(b); - if (a_base == b_base && a_size == b_size) { + if (elementwise && a_base == b_base && a_size == b_size) { return false; } @@ -13874,13 +13873,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co return false; } - // must not overwrite srcs in a way that's not elementwise - ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; - if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) || - ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) { - return false; - } - // conditions for pipeline creation if (!(ctx->device->float_controls_rte_fp16 && sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { @@ -13942,6 +13934,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } +static int32_t find_first_set(uint32_t x) { + int32_t ret = 0; + if (!x) { + return -1; + } + while (!(x & 1)) { + x >>= 1; + ret++; + } + return ret; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -14040,6 +14044,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to + // the fused result in an elementwise-way. This affects whether the memory for + // the src is allowed to overlap the memory for the destination. + // The array is sized to handle the largest fusion (asserted later). + bool op_srcs_fused_elementwise[12]; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; ctx->fused_topk_moe_scale = false; const char *fusion_string {}; @@ -14048,39 +14058,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; fusion_string = "MULTI_ADD"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true); } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ADD_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ID_ADD_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_ADD_ID"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { ctx->num_additional_fused_ops = 4; fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; + op_srcs_fused_elementwise[3] = false; + op_srcs_fused_elementwise[4] = false; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "RMS_NORM_MUL_ROPE"; + // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "RMS_NORM_MUL"; + // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before + // they are overwritten, and one workgroup per row. So close enough. + op_srcs_fused_elementwise[0] = true; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -14089,6 +14128,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { @@ -14097,6 +14137,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 4; ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { @@ -14105,6 +14146,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { @@ -14113,6 +14155,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 1; ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. @@ -14120,11 +14163,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { ctx->fused_topk_moe_scale = true; ctx->num_additional_fused_ops++; + op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false; } } } + GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0]))); ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; + // Check whether fusion would overwrite src operands while they're still in use. + // If so, disable fusion. + if (ctx->num_additional_fused_ops) { + // There are up to two output nodes - topk_moe has two. + uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops); + ggml_tensor *output_nodes[2] {}; + output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops]; + if (bits) { + int output_idx = find_first_set(bits); + GGML_ASSERT(bits == (1u << output_idx)); + output_nodes[1] = cgraph->nodes[i + output_idx]; + } + + bool need_disable = false; + + // topk_moe often overwrites the source, but for a given row all the src values are + // loaded before anything is stored. If there's only one row, this is safe, so treat + // this as a special case. + bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT && + ggml_nrows(cgraph->nodes[i]->src[0]) == 1; + + if (!is_topk_moe_single_row) { + for (int j = 0; j < 2; ++j) { + ggml_tensor *dst = output_nodes[j]; + if (!dst) { + continue; + } + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + ggml_tensor *src = cgraph->nodes[i + k]->src[s]; + if (!src || src->op == GGML_OP_NONE) { + continue; + } + if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) { + bool found = false; + for (int n = 0; n < k; ++n) { + if (cgraph->nodes[i + n] == src) { + found = true; + break; + } + } + if (!found) { + need_disable = true; + } + } + } + } + } + } + if (need_disable) { + ctx->num_additional_fused_ops = 0; + ctx->fused_ops_write_mask = 1; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; + } + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) ||