diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 4ed5f35774..a9d1778641 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -358,7 +358,7 @@ extern "C" { typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); // Compare the output of two backends - GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 8547ecc849..1b59924b8c 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { ggml_free(copy.ctx_unallocated); } -bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) { struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); if (copy.buffer == NULL) { return false; @@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t assert(g1->n_nodes == g2->n_nodes); - if (test_node != nullptr) { - // Compute the whole graph and only test the output for a specific tensor + if (num_test_nodes != 0) { + GGML_ASSERT(test_nodes); + // Compute the whole graph and only test the output for specific tensors ggml_backend_graph_compute(backend1, g1); ggml_backend_graph_compute(backend2, g2); - int test_node_idx = -1; + bool verified = false; for (int i = 0; i < g1->n_nodes; i++) { - struct ggml_tensor * t1 = g1->nodes[i]; - if (t1 == test_node) { - test_node_idx = i; - break; + for (size_t j = 0; j < num_test_nodes; ++j) { + if (g1->nodes[i] == test_nodes[j]) { + callback(i, g1->nodes[i], g2->nodes[i], user_data); + verified = true; + } } } - GGML_ASSERT(test_node_idx != -1); - - callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data); + GGML_ASSERT(verified); } else { for (int i = 0; i < g1->n_nodes; i++) { struct ggml_tensor * t1 = g1->nodes[i]; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 493ee9c9a4..541e4a50b7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -434,8 +434,15 @@ static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGM GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }; + +static constexpr std::initializer_list topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD, + GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, + GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, + GGML_OP_DIV, GGML_OP_RESHAPE }; + static constexpr std::initializer_list topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }; + static constexpr std::initializer_list topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; @@ -464,6 +471,32 @@ static constexpr std::initializer_list> topk_moe_early_softma { 9, 0, 8 }, // reshape->src[0] == div }; +//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ] +//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] +//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ] +//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ] +//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ] +//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ] +//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ] +//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] +//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ] +//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ] +//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ] +static constexpr std::initializer_list> topk_moe_sigmoid_norm_bias_edges { + { 1, 0, 0 }, // reshape->src[0] == sigmoid + { 2, 0, 0 }, // add->src[0] == sigmoid + { 3, 0, 2 }, // argsort->src[0] == add + { 4, 0, 3 }, // view->src[0] == argsort + { 5, 0, 1 }, // get_rows->src[0] == reshape + { 5, 1, 4 }, // get_rows->src[1] == view + { 6, 0, 5 }, // reshape->src[0] == get_rows + { 7, 0, 6 }, // sum_rows->src[0] == reshape + { 8, 0, 7 }, // clamp->src[0] == sum_rows + { 9, 0, 6 }, // div->src[0] == reshape + { 9, 1, 8 }, // div->src[1] == clamp + {10, 0, 9 }, // reshape->src[0] == div +}; + // same as early_softmax_norm but ending after the get_rows static constexpr std::initializer_list> topk_moe_early_softmax_edges { { 1, 0, 0 }, // reshape->src[0] == softmax @@ -491,16 +524,10 @@ enum topk_moe_mode { TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_EARLY_SOFTMAX_NORM, TOPK_MOE_LATE_SOFTMAX, + TOPK_MOE_SIGMOID_NORM_BIAS, TOPK_MOE_COUNT, }; -static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { - topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : - num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX : - TOPK_MOE_LATE_SOFTMAX; - return mode; -} - static constexpr std::initializer_list> rope_view_set_rows_edges { { 1, 0, 0 }, // view->src[0] == rope { 2, 0, 1 }, // set_rows->src[0] == view @@ -766,7 +793,7 @@ struct vk_device_struct { vk_pipeline pipeline_count_experts; // [2] is for whether to take n_experts from spec constant (0) or push constant (1) - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2]; + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; std::vector all_pipelines; @@ -1181,6 +1208,11 @@ struct vk_op_topk_moe_push_constants { uint32_t n_expert_used; float clamp_min; float clamp_max; + uint32_t gating_func; + uint32_t has_bias; + uint32_t with_norm; + float output_scale; + float output_bias; }; struct vk_op_add_id_push_constants { @@ -1771,6 +1803,8 @@ struct ggml_backend_vk_context { // Bit 'i' means nodes[start_of_fusion + i] writes to memory. // If there's no fusion, bit 0 is still set. int fused_ops_write_mask {}; + topk_moe_mode fused_topk_moe_mode {}; + bool fused_topk_moe_scale {}; // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; @@ -4291,9 +4325,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (uint32_t use_push = 0; use_push < 2; ++use_push) { for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); } } @@ -8684,10 +8716,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (ctx->num_additional_fused_ops) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); // use n_experts from push constant if it's not equal to the power of two spec constant bool use_push = dst->ne[0] != (1u << idx); - return ctx->device->pipeline_topk_moe[idx][mode][use_push]; + return ctx->device->pipeline_topk_moe[idx][use_push]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -10346,14 +10377,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + topk_moe_mode mode = ctx->fused_topk_moe_mode; ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : - (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : - cgraph->nodes[node_idx + 5]; - ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; + ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits; + ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] : + (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : + cgraph->nodes[node_idx + 3]; GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(bias->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -10368,6 +10401,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits); + vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias); vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights); vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids); @@ -10375,18 +10409,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, pc.n_rows = n_rows; pc.n_experts_push = n_experts; pc.n_expert_used = n_expert_used; + pc.clamp_min = -std::numeric_limits::infinity(); + pc.clamp_max = std::numeric_limits::infinity(); if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; + GGML_ASSERT(clamp->op == GGML_OP_CLAMP); pc.clamp_min = ggml_get_op_params_f32(clamp, 0); pc.clamp_max = ggml_get_op_params_f32(clamp, 1); } + if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) { + ggml_tensor * clamp = cgraph->nodes[node_idx + 8]; + GGML_ASSERT(clamp->op == GGML_OP_CLAMP); + pc.clamp_min = ggml_get_op_params_f32(clamp, 0); + pc.clamp_max = ggml_get_op_params_f32(clamp, 1); + } + +#define GATING_FUNC_SOFTMAX 0 +#define GATING_FUNC_SIGMOID 1 +#define GATING_FUNC_SOFTMAX_WEIGHT 2 + + pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID : + mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT : + GATING_FUNC_SOFTMAX; + pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS; + pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS; + if (ctx->fused_topk_moe_scale) { + GGML_ASSERT(weights->op == GGML_OP_SCALE); + pc.output_scale = ggml_get_op_params_f32(weights, 0); + pc.output_bias = ggml_get_op_params_f32(weights, 1); + } else { + pc.output_scale = 1.0f; + pc.output_bias = 0.0f; + } GGML_ASSERT(n_expert_used <= n_experts); const uint32_t rows_per_block = 4; std::array elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { @@ -12128,6 +12189,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_UNARY: + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { + ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); + break; + } + switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: @@ -12175,7 +12241,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SOFT_MAX: - if (ctx->num_additional_fused_ops) { + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); } else { ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node); @@ -12195,7 +12261,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); } else { ggml_vk_argsort(ctx, compute_ctx, src0, node); @@ -13048,6 +13114,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc get_rows = cgraph->nodes[node_idx + 4]; argsort = cgraph->nodes[node_idx + 2]; break; + case TOPK_MOE_SIGMOID_NORM_BIAS: + softmax = cgraph->nodes[node_idx + 0]; // really sigmoid + weights = cgraph->nodes[node_idx + 10]; + get_rows = cgraph->nodes[node_idx + 5]; + argsort = cgraph->nodes[node_idx + 3]; + if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) { + return false; + } + // bias is expected to be 1D + if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 || + !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) { + return false; + } + // sigmoid fusion seems to generate infinities on moltenvk + if (ctx->device->driver_id == vk::DriverId::eMoltenvk) { + return false; + } + break; case TOPK_MOE_EARLY_SOFTMAX: softmax = cgraph->nodes[node_idx + 0]; weights = cgraph->nodes[node_idx + 4]; @@ -13071,26 +13155,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc probs = probs->src[0]; ggml_tensor * selection_probs = argsort->src[0]; - if (probs != selection_probs) { + if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) { return false; } - const float * op_params = (const float *)softmax->op_params; - - float scale = op_params[0]; - float max_bias = op_params[1]; - if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { return false; } - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } + if (softmax->op == GGML_OP_SOFT_MAX) { + const float * op_params = (const float *)softmax->op_params; - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; + float scale = op_params[0]; + float max_bias = op_params[1]; + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } } const int n_expert = softmax->ne[0]; @@ -13363,6 +13449,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; const char *fusion_string {}; if (!ctx->device->disable_fusion) { uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); @@ -13408,13 +13496,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && + ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { + ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 4; + ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; + fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && @@ -13422,8 +13520,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 1; + ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; } + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { + // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. + if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) || + ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { + ctx->fused_topk_moe_scale = true; + ctx->num_additional_fused_ops++; + } + } } ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; @@ -13602,6 +13709,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_early_softmax_norm)) { continue; } + if (keep_pattern(topk_moe_sigmoid_norm_bias)) { + continue; + } if (keep_pattern(topk_moe_early_softmax)) { continue; } @@ -13628,6 +13738,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } // Don't pull forward nodes from fusion patterns if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || match_pattern(topk_moe_late_softmax, j)) { continue; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index b83a2b9d2d..4bf6d2bcb0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -7,6 +7,10 @@ #include "types.glsl" +#define GATING_FUNC_SOFTMAX 0 +#define GATING_FUNC_SIGMOID 1 +#define GATING_FUNC_SOFTMAX_WEIGHT 2 + layout (push_constant) uniform parameter { uint n_rows; @@ -14,15 +18,18 @@ layout (push_constant) uniform parameter uint n_expert_used; float clamp_min; float clamp_max; + uint gating_func; + uint has_bias; + uint with_norm; + float output_scale; + float output_bias; }; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 1) const uint n_experts_spec = 512; -layout(constant_id = 2) const bool with_norm = true; -layout(constant_id = 3) const bool late_softmax = false; -layout(constant_id = 4) const bool nexperts_use_push = false; +layout(constant_id = 2) const bool nexperts_use_push = false; uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; @@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE); layout (binding = 0, std430) readonly buffer Logits {float logits[];}; -layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; -layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; +layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];}; +layout (binding = 2, std430) writeonly buffer Weights {float weights[];}; +layout (binding = 3, std430) writeonly buffer Ids {uint ids[];}; const float INFINITY = 1.0 / 0.0; @@ -87,20 +95,40 @@ void main() { } const uint logits_offset = n_experts * row; + const uint bias_offset = 0; // 1D const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; const uint lane = gl_SubgroupInvocationID; - float wt[experts_per_thread]; + float probs[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { const uint expert = i + lane; - wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; + probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } - if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push); + if (gating_func == GATING_FUNC_SOFTMAX) { + softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push); + } else if (gating_func == GATING_FUNC_SIGMOID) { + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + probs[i] = 1.f / (1.f + exp(-probs[i])); + } + } + + float selection_probs[experts_per_thread]; + if (has_bias != 0) { + [[unroll]] + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY; + } + } else { + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + selection_probs[i] = probs[i]; + } } // at this point, each thread holds a portion of softmax, @@ -117,14 +145,16 @@ void main() { } for (int k = 0; k < n_expert_used; k++) { - float max_val = wt[0]; + float max_val = probs[0]; + float max_val_s = selection_probs[0]; uint max_expert = lane; [[unroll]] for (int i = 1; i < experts_per_thread; i++) { const uint expert = lane + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { - max_val = wt[i]; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) { + max_val = probs[i]; + max_val_s = selection_probs[i]; max_expert = expert; } } @@ -132,9 +162,11 @@ void main() { [[unroll]] for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) { const float val = subgroupShuffleXor(max_val, mask); + const float val_s = subgroupShuffleXor(max_val_s, mask); const uint expert = subgroupShuffleXor(max_expert, mask); - if (val > max_val || (val == max_val && expert < max_expert)) { + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { max_val = val; + max_val_s = val_s; max_expert = expert; } } @@ -144,16 +176,14 @@ void main() { } if ((max_expert & (WARP_SIZE - 1)) == lane) { - wt[max_expert / WARP_SIZE] = -INFINITY; + selection_probs[max_expert / WARP_SIZE] = -INFINITY; ids[ids_offset + k] = max_expert; - if (with_norm) { - wt_sum += max_val; - } + wt_sum += max_val; } } - if (with_norm) { + if (with_norm != 0) { wt_sum = subgroupAdd(wt_sum); wt_sum = clamp(wt_sum, clamp_min, clamp_max); const float inv_sum = 1.0f / wt_sum; @@ -164,7 +194,7 @@ void main() { } } - if (late_softmax) { + if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { softmax_warp_inplace(output_weights, n_expert_used, lane, true); } @@ -172,7 +202,7 @@ void main() { for (uint i = 0; i < experts_per_thread; ++i) { uint idx = i * WARP_SIZE + lane; if (idx < n_expert_used) { - weights[weights_offset + idx] = output_weights[i]; + weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0b981b1788..6dedd8de58 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1158,6 +1158,7 @@ struct test_case { } virtual bool run_whole_graph() { return false; } + virtual std::vector fusion_test_nodes() { return {}; } ggml_cgraph * gf = nullptr; ggml_cgraph * gb = nullptr; @@ -1391,7 +1392,13 @@ struct test_case { GGML_UNUSED(index); }; - const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr); + std::vector fused_nodes_to_verify = fusion_test_nodes(); + if (fused_nodes_to_verify.size() == 0 && run_whole_graph()) { + fused_nodes_to_verify.push_back(out); + } + const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, + run_whole_graph() ? fused_nodes_to_verify.data() : nullptr, + fused_nodes_to_verify.size()); ggml_backend_buffer_free(buf); @@ -5180,6 +5187,8 @@ struct test_topk_moe : public test_case { const bool bias_probs; const MoeGatingFunc gating_func; const float scale_w; + ggml_tensor * weights {}; + ggml_tensor * selected_experts {}; test_topk_moe(std::array ne = { 10, 5, 1, 1 }, int n_expert_used = 1, @@ -5217,16 +5226,16 @@ struct test_topk_moe : public test_case { ggml_tensor * selection_probs = probs; if (bias_probs) { - ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); + ggml_tensor * exp_probs_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); ggml_set_name(exp_probs_b, "exp_probs_b"); selection_probs = ggml_add(ctx, probs, exp_probs_b); ggml_set_name(selection_probs, "selection_probs"); } - ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] ggml_set_name(selected_experts, "selected_experts"); - ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] ggml_set_name(weights, "weights"); if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { @@ -5252,6 +5261,21 @@ struct test_topk_moe : public test_case { ggml_set_name(weights, "weights"); return weights; } + // Verify two outputs + std::vector fusion_test_nodes() override { return { selected_experts, weights }; } + + // allow output in arbitrary order + double err(const float * a, const float * b, size_t n) override { + std::vector a2(n); + std::vector b2(n); + for (size_t i = 0; i < n; ++i) { + a2[i] = a[i]; + b2[i] = b[i]; + } + std::sort(a2.begin(), a2.end()); + std::sort(b2.begin(), b2.end()); + return nmse(a2.data(), b2.data(), n); + } }; struct test_mul_mat_vec_fusion : public test_case {