sampling : use argmax for min-p sampling

This commit is contained in:
Daniel Bevenius 2025-11-27 07:38:44 +01:00
parent 7c2bfb352e
commit d9d736102b
No known key found for this signature in database
1 changed files with 5 additions and 12 deletions

View File

@ -516,24 +516,17 @@ static void llama_sampler_backend_min_p_apply_ggml(
ggml_set_name(softmax, "softmax");
// Get the sorted indices of the softmax probabilities in descending order.
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
ggml_set_name(sorted_idx, "sorted_idx");
struct ggml_tensor * max_idx = ggml_argmax(ctx, softmax);
ggml_set_name(max_idx, "max_idx");
// Reshape into a row vector.
struct ggml_tensor * softmax_rows = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
ggml_set_name(softmax_rows, "softmax_rows");
// Get the sorted probabilities using the sorted indices so that we can get
// the max probability value, which will be the first entry in sorted_probs.
struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_rows, sorted_idx);
ggml_set_name(sorted_probs, "sorted_probs");
// Get the max probability value from sorted_probs.
struct ggml_tensor * p_max = ggml_view_1d(ctx, sorted_probs, 1, 0);
ggml_set_name(p_max, "p_max");
struct ggml_tensor * max_prob = ggml_get_rows(ctx, softmax_rows, max_idx);
ggml_set_name(max_prob, "max_prob");
// Calculate the threshold value.
struct ggml_tensor * threshold = ggml_scale(ctx, p_max, sctx->p);
struct ggml_tensor * threshold = ggml_scale(ctx, max_prob, sctx->p);
ggml_set_name(threshold, "min_p_threshold");
// Broadcast the threshold to match the shape of softmax.