vulkan: extend topk_moe to handle sigmoid w/exp_probs_b for nemotron (#18295)
* vulkan: extend topk_moe to handle sigmoid w/exp_probs_b for nemotron Also handle GGML_OP_SCALE at the end (nemotron, deepseek2). Fewer pipeline variants and spec constants, just use push constants. In test_topk_moe, change exp_probs_b to be 1D, matching real networks. Update test-backend-ops and ggml-backend to allow verifying multiple outputs in a fusion test (topk_moe has two outputs). Previously only the final node was verified. * change test_topk_moe to allow results in arbitrary order * disable sigmoid fusion for moltenvk
This commit is contained in:
parent
9e10bd2eaf
commit
be47fb9285
|
|
@ -358,7 +358,7 @@ extern "C" {
|
|||
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
||||
|
||||
// Compare the output of two backends
|
||||
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
|
||||
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
|
||||
|
||||
// Tensor initialization
|
||||
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
||||
|
|
|
|||
|
|
@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
|
|||
ggml_free(copy.ctx_unallocated);
|
||||
}
|
||||
|
||||
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
|
||||
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {
|
||||
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
|
||||
if (copy.buffer == NULL) {
|
||||
return false;
|
||||
|
|
@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
|
|||
|
||||
assert(g1->n_nodes == g2->n_nodes);
|
||||
|
||||
if (test_node != nullptr) {
|
||||
// Compute the whole graph and only test the output for a specific tensor
|
||||
if (num_test_nodes != 0) {
|
||||
GGML_ASSERT(test_nodes);
|
||||
// Compute the whole graph and only test the output for specific tensors
|
||||
ggml_backend_graph_compute(backend1, g1);
|
||||
ggml_backend_graph_compute(backend2, g2);
|
||||
|
||||
int test_node_idx = -1;
|
||||
bool verified = false;
|
||||
for (int i = 0; i < g1->n_nodes; i++) {
|
||||
struct ggml_tensor * t1 = g1->nodes[i];
|
||||
if (t1 == test_node) {
|
||||
test_node_idx = i;
|
||||
break;
|
||||
for (size_t j = 0; j < num_test_nodes; ++j) {
|
||||
if (g1->nodes[i] == test_nodes[j]) {
|
||||
callback(i, g1->nodes[i], g2->nodes[i], user_data);
|
||||
verified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(test_node_idx != -1);
|
||||
|
||||
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
|
||||
GGML_ASSERT(verified);
|
||||
} else {
|
||||
for (int i = 0; i < g1->n_nodes; i++) {
|
||||
struct ggml_tensor * t1 = g1->nodes[i];
|
||||
|
|
|
|||
|
|
@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
|
|||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
|
||||
GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
|
||||
GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
|
||||
GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
|
@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
|
|||
{ 9, 0, 8 }, // reshape->src[0] == div
|
||||
};
|
||||
|
||||
//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
|
||||
//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
|
||||
//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
|
||||
//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
|
||||
//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
|
||||
//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
|
||||
//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
|
||||
//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
|
||||
//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
|
||||
//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
|
||||
//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
|
||||
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
|
||||
{ 1, 0, 0 }, // reshape->src[0] == sigmoid
|
||||
{ 2, 0, 0 }, // add->src[0] == sigmoid
|
||||
{ 3, 0, 2 }, // argsort->src[0] == add
|
||||
{ 4, 0, 3 }, // view->src[0] == argsort
|
||||
{ 5, 0, 1 }, // get_rows->src[0] == reshape
|
||||
{ 5, 1, 4 }, // get_rows->src[1] == view
|
||||
{ 6, 0, 5 }, // reshape->src[0] == get_rows
|
||||
{ 7, 0, 6 }, // sum_rows->src[0] == reshape
|
||||
{ 8, 0, 7 }, // clamp->src[0] == sum_rows
|
||||
{ 9, 0, 6 }, // div->src[0] == reshape
|
||||
{ 9, 1, 8 }, // div->src[1] == clamp
|
||||
{10, 0, 9 }, // reshape->src[0] == div
|
||||
};
|
||||
|
||||
// same as early_softmax_norm but ending after the get_rows
|
||||
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
|
||||
{ 1, 0, 0 }, // reshape->src[0] == softmax
|
||||
|
|
@ -491,16 +524,10 @@ enum topk_moe_mode {
|
|||
TOPK_MOE_EARLY_SOFTMAX,
|
||||
TOPK_MOE_EARLY_SOFTMAX_NORM,
|
||||
TOPK_MOE_LATE_SOFTMAX,
|
||||
TOPK_MOE_SIGMOID_NORM_BIAS,
|
||||
TOPK_MOE_COUNT,
|
||||
};
|
||||
|
||||
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
|
||||
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
|
||||
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
|
||||
TOPK_MOE_LATE_SOFTMAX;
|
||||
return mode;
|
||||
}
|
||||
|
||||
static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
|
||||
{ 1, 0, 0 }, // view->src[0] == rope
|
||||
{ 2, 0, 1 }, // set_rows->src[0] == view
|
||||
|
|
@ -766,7 +793,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_count_experts;
|
||||
|
||||
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
|
||||
|
||||
std::vector<vk_pipeline_ref> all_pipelines;
|
||||
|
||||
|
|
@ -1181,6 +1208,11 @@ struct vk_op_topk_moe_push_constants {
|
|||
uint32_t n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
uint32_t gating_func;
|
||||
uint32_t has_bias;
|
||||
uint32_t with_norm;
|
||||
float output_scale;
|
||||
float output_bias;
|
||||
};
|
||||
|
||||
struct vk_op_add_id_push_constants {
|
||||
|
|
@ -1771,6 +1803,8 @@ struct ggml_backend_vk_context {
|
|||
// Bit 'i' means nodes[start_of_fusion + i] writes to memory.
|
||||
// If there's no fusion, bit 0 is still set.
|
||||
int fused_ops_write_mask {};
|
||||
topk_moe_mode fused_topk_moe_mode {};
|
||||
bool fused_topk_moe_scale {};
|
||||
|
||||
// for GGML_VK_PERF_LOGGER
|
||||
std::unique_ptr<vk_perf_logger> perf_logger;
|
||||
|
|
@ -4291,9 +4325,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
for (uint32_t use_push = 0; use_push < 2; ++use_push) {
|
||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -8684,10 +8716,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
if (ctx->num_additional_fused_ops) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
// use n_experts from push constant if it's not equal to the power of two spec constant
|
||||
bool use_push = dst->ne[0] != (1u << idx);
|
||||
return ctx->device->pipeline_topk_moe[idx][mode][use_push];
|
||||
return ctx->device->pipeline_topk_moe[idx][use_push];
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||
|
|
@ -10346,14 +10377,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
}
|
||||
|
||||
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
topk_moe_mode mode = ctx->fused_topk_moe_mode;
|
||||
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
||||
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
|
||||
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
|
||||
cgraph->nodes[node_idx + 5];
|
||||
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
|
||||
ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
|
||||
ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
|
||||
(mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
|
||||
cgraph->nodes[node_idx + 3];
|
||||
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
|
|
@ -10368,6 +10401,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
|
||||
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
|
||||
vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
|
||||
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
|
||||
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
|
||||
|
||||
|
|
@ -10375,18 +10409,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
pc.n_rows = n_rows;
|
||||
pc.n_experts_push = n_experts;
|
||||
pc.n_expert_used = n_expert_used;
|
||||
pc.clamp_min = -std::numeric_limits<float>::infinity();
|
||||
pc.clamp_max = std::numeric_limits<float>::infinity();
|
||||
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
|
||||
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
|
||||
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
|
||||
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
|
||||
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
|
||||
}
|
||||
if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
|
||||
ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
|
||||
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
|
||||
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
|
||||
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
|
||||
}
|
||||
|
||||
#define GATING_FUNC_SOFTMAX 0
|
||||
#define GATING_FUNC_SIGMOID 1
|
||||
#define GATING_FUNC_SOFTMAX_WEIGHT 2
|
||||
|
||||
pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
|
||||
mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
|
||||
GATING_FUNC_SOFTMAX;
|
||||
pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
|
||||
pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
|
||||
if (ctx->fused_topk_moe_scale) {
|
||||
GGML_ASSERT(weights->op == GGML_OP_SCALE);
|
||||
pc.output_scale = ggml_get_op_params_f32(weights, 0);
|
||||
pc.output_bias = ggml_get_op_params_f32(weights, 1);
|
||||
} else {
|
||||
pc.output_scale = 1.0f;
|
||||
pc.output_bias = 0.0f;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_expert_used <= n_experts);
|
||||
|
||||
const uint32_t rows_per_block = 4;
|
||||
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
|
||||
|
|
@ -12128,6 +12189,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
|
||||
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
|
||||
break;
|
||||
}
|
||||
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
|
|
@ -12175,7 +12241,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
|
||||
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
|
||||
} else {
|
||||
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
|
||||
|
|
@ -12195,7 +12261,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
|
||||
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
|
||||
} else {
|
||||
ggml_vk_argsort(ctx, compute_ctx, src0, node);
|
||||
|
|
@ -13048,6 +13114,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
|
|||
get_rows = cgraph->nodes[node_idx + 4];
|
||||
argsort = cgraph->nodes[node_idx + 2];
|
||||
break;
|
||||
case TOPK_MOE_SIGMOID_NORM_BIAS:
|
||||
softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
|
||||
weights = cgraph->nodes[node_idx + 10];
|
||||
get_rows = cgraph->nodes[node_idx + 5];
|
||||
argsort = cgraph->nodes[node_idx + 3];
|
||||
if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
|
||||
return false;
|
||||
}
|
||||
// bias is expected to be 1D
|
||||
if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
|
||||
!ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
|
||||
return false;
|
||||
}
|
||||
// sigmoid fusion seems to generate infinities on moltenvk
|
||||
if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case TOPK_MOE_EARLY_SOFTMAX:
|
||||
softmax = cgraph->nodes[node_idx + 0];
|
||||
weights = cgraph->nodes[node_idx + 4];
|
||||
|
|
@ -13071,26 +13155,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
|
|||
probs = probs->src[0];
|
||||
ggml_tensor * selection_probs = argsort->src[0];
|
||||
|
||||
if (probs != selection_probs) {
|
||||
if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const float * op_params = (const float *)softmax->op_params;
|
||||
|
||||
float scale = op_params[0];
|
||||
float max_bias = op_params[1];
|
||||
|
||||
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
if (softmax->op == GGML_OP_SOFT_MAX) {
|
||||
const float * op_params = (const float *)softmax->op_params;
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
float scale = op_params[0];
|
||||
float max_bias = op_params[1];
|
||||
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const int n_expert = softmax->ne[0];
|
||||
|
|
@ -13363,6 +13449,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
total_mul_mat_bytes += bytes;
|
||||
}
|
||||
|
||||
ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
|
||||
ctx->fused_topk_moe_scale = false;
|
||||
const char *fusion_string {};
|
||||
if (!ctx->device->disable_fusion) {
|
||||
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
|
||||
|
|
@ -13408,13 +13496,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 3;
|
||||
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
|
||||
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 4;
|
||||
ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
|
||||
fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 3;
|
||||
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
|
||||
fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||
|
|
@ -13422,8 +13520,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 1;
|
||||
ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
|
||||
fusion_string = "TOPK_MOE_LATE_SOFTMAX";
|
||||
}
|
||||
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
|
||||
// Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
|
||||
if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
|
||||
ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
|
||||
ctx->fused_topk_moe_scale = true;
|
||||
ctx->num_additional_fused_ops++;
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
|
||||
|
||||
|
|
@ -13602,6 +13709,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
if (keep_pattern(topk_moe_early_softmax_norm)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(topk_moe_early_softmax)) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -13628,6 +13738,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
}
|
||||
// Don't pull forward nodes from fusion patterns
|
||||
if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
||||
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
|
||||
match_pattern(topk_moe_early_softmax, j) ||
|
||||
match_pattern(topk_moe_late_softmax, j)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@
|
|||
|
||||
#include "types.glsl"
|
||||
|
||||
#define GATING_FUNC_SOFTMAX 0
|
||||
#define GATING_FUNC_SIGMOID 1
|
||||
#define GATING_FUNC_SOFTMAX_WEIGHT 2
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
|
|
@ -14,15 +18,18 @@ layout (push_constant) uniform parameter
|
|||
uint n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
uint gating_func;
|
||||
uint has_bias;
|
||||
uint with_norm;
|
||||
float output_scale;
|
||||
float output_bias;
|
||||
};
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint n_experts_spec = 512;
|
||||
layout(constant_id = 2) const bool with_norm = true;
|
||||
layout(constant_id = 3) const bool late_softmax = false;
|
||||
layout(constant_id = 4) const bool nexperts_use_push = false;
|
||||
layout(constant_id = 2) const bool nexperts_use_push = false;
|
||||
|
||||
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
|
||||
|
||||
|
|
@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
|
|||
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
|
||||
|
||||
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
|
||||
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
||||
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
|
||||
layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
|
||||
layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
|
||||
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
|
||||
|
||||
const float INFINITY = 1.0 / 0.0;
|
||||
|
||||
|
|
@ -87,20 +95,40 @@ void main() {
|
|||
}
|
||||
|
||||
const uint logits_offset = n_experts * row;
|
||||
const uint bias_offset = 0; // 1D
|
||||
const uint weights_offset = n_expert_used * row;
|
||||
const uint ids_offset = n_experts * row;
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
|
||||
float wt[experts_per_thread];
|
||||
float probs[experts_per_thread];
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const uint expert = i + lane;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
||||
probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
||||
}
|
||||
|
||||
if (!late_softmax) {
|
||||
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
|
||||
if (gating_func == GATING_FUNC_SOFTMAX) {
|
||||
softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
|
||||
} else if (gating_func == GATING_FUNC_SIGMOID) {
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
probs[i] = 1.f / (1.f + exp(-probs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
float selection_probs[experts_per_thread];
|
||||
if (has_bias != 0) {
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const uint expert = i + lane;
|
||||
selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
|
||||
}
|
||||
} else {
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
selection_probs[i] = probs[i];
|
||||
}
|
||||
}
|
||||
|
||||
// at this point, each thread holds a portion of softmax,
|
||||
|
|
@ -117,14 +145,16 @@ void main() {
|
|||
}
|
||||
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
float max_val = probs[0];
|
||||
float max_val_s = selection_probs[0];
|
||||
uint max_expert = lane;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const uint expert = lane + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||
max_val = wt[i];
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
|
||||
max_val = probs[i];
|
||||
max_val_s = selection_probs[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
|
@ -132,9 +162,11 @@ void main() {
|
|||
[[unroll]]
|
||||
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = subgroupShuffleXor(max_val, mask);
|
||||
const float val_s = subgroupShuffleXor(max_val_s, mask);
|
||||
const uint expert = subgroupShuffleXor(max_expert, mask);
|
||||
if (val > max_val || (val == max_val && expert < max_expert)) {
|
||||
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_val_s = val_s;
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
|
@ -144,16 +176,14 @@ void main() {
|
|||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == lane) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
selection_probs[max_expert / WARP_SIZE] = -INFINITY;
|
||||
|
||||
ids[ids_offset + k] = max_expert;
|
||||
if (with_norm) {
|
||||
wt_sum += max_val;
|
||||
}
|
||||
wt_sum += max_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (with_norm) {
|
||||
if (with_norm != 0) {
|
||||
wt_sum = subgroupAdd(wt_sum);
|
||||
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
|
@ -164,7 +194,7 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
if (late_softmax) {
|
||||
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
|
||||
softmax_warp_inplace(output_weights, n_expert_used, lane, true);
|
||||
}
|
||||
|
||||
|
|
@ -172,7 +202,7 @@ void main() {
|
|||
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||
uint idx = i * WARP_SIZE + lane;
|
||||
if (idx < n_expert_used) {
|
||||
weights[weights_offset + idx] = output_weights[i];
|
||||
weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1158,6 +1158,7 @@ struct test_case {
|
|||
}
|
||||
|
||||
virtual bool run_whole_graph() { return false; }
|
||||
virtual std::vector<ggml_tensor *> fusion_test_nodes() { return {}; }
|
||||
|
||||
ggml_cgraph * gf = nullptr;
|
||||
ggml_cgraph * gb = nullptr;
|
||||
|
|
@ -1391,7 +1392,13 @@ struct test_case {
|
|||
GGML_UNUSED(index);
|
||||
};
|
||||
|
||||
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr);
|
||||
std::vector<ggml_tensor *> fused_nodes_to_verify = fusion_test_nodes();
|
||||
if (fused_nodes_to_verify.size() == 0 && run_whole_graph()) {
|
||||
fused_nodes_to_verify.push_back(out);
|
||||
}
|
||||
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud,
|
||||
run_whole_graph() ? fused_nodes_to_verify.data() : nullptr,
|
||||
fused_nodes_to_verify.size());
|
||||
|
||||
ggml_backend_buffer_free(buf);
|
||||
|
||||
|
|
@ -5180,6 +5187,8 @@ struct test_topk_moe : public test_case {
|
|||
const bool bias_probs;
|
||||
const MoeGatingFunc gating_func;
|
||||
const float scale_w;
|
||||
ggml_tensor * weights {};
|
||||
ggml_tensor * selected_experts {};
|
||||
|
||||
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
|
||||
int n_expert_used = 1,
|
||||
|
|
@ -5217,16 +5226,16 @@ struct test_topk_moe : public test_case {
|
|||
|
||||
ggml_tensor * selection_probs = probs;
|
||||
if (bias_probs) {
|
||||
ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * exp_probs_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
||||
ggml_set_name(exp_probs_b, "exp_probs_b");
|
||||
selection_probs = ggml_add(ctx, probs, exp_probs_b);
|
||||
ggml_set_name(selection_probs, "selection_probs");
|
||||
}
|
||||
|
||||
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
ggml_set_name(selected_experts, "selected_experts");
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
ggml_set_name(weights, "weights");
|
||||
|
||||
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
|
||||
|
|
@ -5252,6 +5261,21 @@ struct test_topk_moe : public test_case {
|
|||
ggml_set_name(weights, "weights");
|
||||
return weights;
|
||||
}
|
||||
// Verify two outputs
|
||||
std::vector<ggml_tensor *> fusion_test_nodes() override { return { selected_experts, weights }; }
|
||||
|
||||
// allow output in arbitrary order
|
||||
double err(const float * a, const float * b, size_t n) override {
|
||||
std::vector<float> a2(n);
|
||||
std::vector<float> b2(n);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
a2[i] = a[i];
|
||||
b2[i] = b[i];
|
||||
}
|
||||
std::sort(a2.begin(), a2.end());
|
||||
std::sort(b2.begin(), b2.end());
|
||||
return nmse(a2.data(), b2.data(), n);
|
||||
}
|
||||
};
|
||||
|
||||
struct test_mul_mat_vec_fusion : public test_case {
|
||||
|
|
|
|||
Loading…
Reference in New Issue