From 217469f07f75db1594eb4b4ac3c4e723d6ea8553 Mon Sep 17 00:00:00 2001
From: Oliver Simons
Date: Mon, 1 Dec 2025 15:24:32 +0100
Subject: [PATCH] Make backend's top_p sampler inclusive

In addition to matching the algorithm proposed in the original
[paper](https://arxiv.org/abs/1904.09751), this resolves the edge case
where `max_p > top_p` for a single logit, in which the mask would
otherwise be empty (and we would thus sample from the whole vocabulary
with equal likelihood).
---
 src/llama-sampling.cpp         | 14 +++++++++++++-
 tests/test-backend-sampler.cpp |  1 +
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8c1570761d..fd4e770e3c 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1138,7 +1138,6 @@ static void llama_sampler_top_p_backend_apply(
     struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
     ggml_set_name(cdf, "top_p_cdf");
 
-    // TODO: Make it inclusive of probability p
     // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
     struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
     ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
@@ -1146,6 +1145,19 @@ static void llama_sampler_top_p_backend_apply(
     struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
     ggml_set_name(mask, "top_p_mask");
 
+    // Summing the 0/1 mask yields the number of kept elements, i.e. the index
+    // of the first element past the threshold, which we also want to keep.
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "dist_index_f32");
+
+    // Make top-p inclusive (i.e. also keep the first token whose cumulative probability reaches p)
+    struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
+    // construct ones tensor to set the value in the mask
+    struct ggml_tensor * ones = ggml_dup_tensor(ctx, mask_reshaped);
+    ones = ggml_clamp(ctx, ones, 1.0f, 1.0f);
+    mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, ggml_repeat(ctx, idxf, mask), GGML_TYPE_I32));
+    mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
+
     // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
     // top_p_bias = (mask * 1e9f) - 1e9f.
     // So entries in the mask that we want to discard will become -1e9f, and
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index 6251a5ab1c..f185cebe9d 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -512,6 +512,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
         }
     }
     GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
+    GGML_ASSERT(filtered_logits.size() > 0);
 
     // Sample using CPU sampler for verification to inspect they are reasonable
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();