diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 42ccb5b76a..43620df780 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -973,7 +973,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // mask out the other groups selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens] - selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] + selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens] cb(selection_probs, "ffn_moe_probs_masked", il); }