sampling : use logits directly for min-p filtering

This commit is contained in:
Daniel Bevenius 2025-11-28 16:12:05 +01:00
parent 333da805fe
commit 8cac9dee45
No known key found for this signature in database
1 changed files with 11 additions and 15 deletions

View File

@ -512,31 +512,27 @@ static void llama_sampler_backend_min_p_apply_ggml(
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
ggml_set_name(softmax, "softmax");
// Get the sorted indices of the softmax probabilities in descending order.
struct ggml_tensor * max_idx = ggml_argmax(ctx, softmax);
struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits);
ggml_set_name(max_idx, "max_idx");
struct ggml_tensor * softmax_rows = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
ggml_set_name(softmax_rows, "softmax_rows");
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
ggml_set_name(logits_rows, "logits_rows");
struct ggml_tensor * max_prob = ggml_get_rows(ctx, softmax_rows, max_idx);
ggml_set_name(max_prob, "max_prob");
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
ggml_set_name(max_logit, "max_logit");
// Calculate the threshold value.
struct ggml_tensor * threshold = ggml_scale(ctx, max_prob, sctx->p);
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
ggml_set_name(threshold, "min_p_threshold");
// Broadcast the threshold to match the shape of softmax.
struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, softmax);
// Broadcast the threshold to match the shape of logits.
struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, ggml_data->logits);
ggml_set_name(threshold_b, "min_p_threshold_b");
// Subtract the threshold from softmax probabilities.
struct ggml_tensor * sub = ggml_sub(ctx, softmax, threshold_b);
// Subtract the threshold from logits.
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold_b);
// Create a mask where probabilities below the threshold are 0 (discard),
// Create a mask where logits below the threshold are 0 (discard),
// and others are 1 (keep).
struct ggml_tensor * mask = ggml_step(ctx, sub);
ggml_set_name(mask, "min_p_mask");