sampling : add min-p backend sampler
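
Min-p filtering keeps only the tokens whose probability is at least p times
the probability of the most likely token; everything below that threshold is
masked out before the final sampling step. This change wires the existing
min_p parameter into the backend sampler chain, adds a GGML graph
implementation of the filter, and adds a backend sampler test.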

Daniel Bevenius 2025-11-26 10:50:58 +01:00
parent f23b306cc5
commit b45d504e70
4 changed files with 188 additions and 0 deletions


@@ -173,6 +173,7 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
    switch (type) {
        case COMMON_SAMPLER_TYPE_TOP_K:
        case COMMON_SAMPLER_TYPE_TEMPERATURE:
        case COMMON_SAMPLER_TYPE_MIN_P:
            return true;
        default:
            return false;
@@ -325,6 +326,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
            }
            backend_idx++;
            break;
        case COMMON_SAMPLER_TYPE_MIN_P:
            if (params.min_p > 0.0f) {
                llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_min_p(params.min_p));
            }
            backend_idx++;
            break;
        default:
            GGML_ASSERT(false && "unsupported backend sampler");
    }
@@ -468,6 +475,12 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
            }
            backend_idx++;
            break;
        case COMMON_SAMPLER_TYPE_MIN_P:
            if (params.min_p > 0.0f) {
                llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
            }
            backend_idx++;
            break;
        default:
            GGML_ASSERT(false && "unsupported backend sampler");
    }
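
// Note that backend_idx is advanced even when the sampler is skipped
// (min_p == 0.0f), presumably so the index stays aligned with the configured
// sampler list. A minimal usage sketch follows — not part of the commit, and
// it assumes the usual common_params_sampling fields (samplers, min_p) and
// that common_sampler_init() builds the backend chain from them as in the
// hunks above:
static common_sampler * make_min_p_sampler(const llama_model * model) {
    common_params_sampling sparams;
    sparams.min_p    = 0.05f; // > 0.0f, so the MIN_P case above adds the backend sampler
    sparams.samplers = { COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_TEMPERATURE };
    return common_sampler_init(model, sparams);
}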


@@ -1405,6 +1405,9 @@ extern "C" {
    /// @details Distribution sampling on backend - final sampling step that selects a token
    LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed);

    /// @details Min-P filtering on backend - filters out tokens whose probability is less than p times the maximum probability.
    LLAMA_API struct llama_sampler * llama_sampler_backend_init_min_p(float p);

    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
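
// For reference, a minimal sketch (not part of the commit) of driving the new
// entry point: a backend chain that filters with min-p and then samples with
// the backend dist sampler, mirroring how the test below composes chains.
static struct llama_sampler * make_backend_min_p_chain(void) {
    struct llama_sampler_chain_params params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(params);

    llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(0.05f));
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(LLAMA_DEFAULT_SEED));

    // The caller owns the chain: free it with llama_sampler_free(), or hand
    // ownership to a context via a llama_sampler_seq_config as the test does.
    return chain;
}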


@@ -488,3 +488,118 @@ struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
    return sampler;
}

struct llama_sampler_backend_min_p_ctx {
    float p;
    // Only required for checking operation support and can be removed later.
    ggml_backend_dev_t device;
};

static void llama_sampler_backend_min_p_init_ggml(
        struct llama_sampler * smpl,
        ggml_backend_buffer_type_t buft) {
    auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
    sctx->device = ggml_backend_buft_get_device(buft);
}

static void llama_sampler_backend_min_p_apply_ggml(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_ggml_data * ggml_data) {
    GGML_UNUSED(gf);

    auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;

    struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
    ggml_set_name(softmax, "softmax");

    // Get the sorted indices of the softmax probabilities in descending order.
    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
    ggml_set_name(sorted_idx, "sorted_idx");

    // Reshape so that each probability is its own row, allowing ggml_get_rows
    // to gather them by sorted index.
    struct ggml_tensor * softmax_rows = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
    ggml_set_name(softmax_rows, "softmax_rows");

    // Gather the probabilities in sorted order so that the maximum
    // probability ends up as the first entry of sorted_probs.
    struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_rows, sorted_idx);
    ggml_set_name(sorted_probs, "sorted_probs");

    // Get the max probability value from sorted_probs.
    struct ggml_tensor * p_max = ggml_view_1d(ctx, sorted_probs, 1, 0);
    ggml_set_name(p_max, "p_max");

    // Calculate the threshold value.
    struct ggml_tensor * threshold = ggml_scale(ctx, p_max, sctx->p);
    ggml_set_name(threshold, "min_p_threshold");

    // Broadcast the threshold to match the shape of softmax.
    struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, softmax);
    ggml_set_name(threshold_b, "min_p_threshold_b");

    // Subtract the threshold from the softmax probabilities.
    struct ggml_tensor * sub = ggml_sub(ctx, softmax, threshold_b);

    // Create a mask where probabilities below the threshold are 0 (discard),
    // and others are 1 (keep).
    struct ggml_tensor * mask = ggml_step(ctx, sub);
    ggml_set_name(mask, "min_p_mask");

    // Use ggml_scale_bias (output = (a * s) + b), which in this case becomes:
    //   min_p_bias = (mask * 1e9f) - 1e9f
    // so entries that we want to discard become -1e9f, and the others become
    // 0 (leaving the corresponding logits unaffected).
    const float large_val = 1e9f;
    struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
    ggml_set_name(min_p_bias, "min_p_bias");

    // Add the min-p bias to the logits.
    ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias);
    ggml_set_name(ggml_data->logits, "min_p_logits");

    ggml_build_forward_expand(gf, ggml_data->logits);
}
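
// Not part of the commit: a CPU reference sketch of the masking performed by
// the graph above. The argsort/get_rows sequence in the graph only serves to
// obtain p_max, so this hypothetical helper takes the maximum directly.
#include <algorithm>
#include <cmath>
#include <vector>

static void min_p_mask_reference(std::vector<float> & logits, float p) {
    // softmax over the logits
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & pr : probs) {
        pr /= sum;
    }

    // threshold = p * p_max, exactly as ggml_scale(p_max, p) above
    const float p_max     = *std::max_element(probs.begin(), probs.end());
    const float threshold = p * p_max;

    // Mirror ggml_step(softmax - threshold) + ggml_scale_bias: tokens at or
    // below the threshold get a -1e9f bias, the rest are left untouched.
    for (size_t i = 0; i < logits.size(); ++i) {
        if (probs[i] - threshold <= 0.0f) {
            logits[i] += -1e9f;
        }
    }
}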

static const char * llama_sampler_backend_min_p_name(const struct llama_sampler *) {
    return "backend-min-p";
}

static void llama_sampler_backend_min_p_free(struct llama_sampler * smpl) {
    auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
    delete sctx;
}

static struct llama_sampler * llama_sampler_backend_min_p_clone(const struct llama_sampler * smpl) {
    auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
    return llama_sampler_backend_init_min_p(sctx->p);
}

struct llama_sampler * llama_sampler_backend_init_min_p(float p) {
    static const llama_sampler_i iface = {
        /*.name           =*/ llama_sampler_backend_min_p_name,
        /*.accept         =*/ nullptr,
        /*.apply          =*/ nullptr,
        /*.reset          =*/ nullptr,
        /*.clone          =*/ llama_sampler_backend_min_p_clone,
        /*.free           =*/ llama_sampler_backend_min_p_free,
        /*.apply_ggml     =*/ llama_sampler_backend_min_p_apply_ggml,
        /*.accept_ggml    =*/ nullptr,
        /*.set_input_ggml =*/ nullptr,
        /*.init_ggml      =*/ llama_sampler_backend_min_p_init_ggml,
    };

    auto * sctx = new llama_sampler_backend_min_p_ctx {
        /*.p      =*/ p,
        /*.device =*/ nullptr,
    };

    auto * sampler = new llama_sampler {
        /*.iface =*/ &iface,
        /*.ctx   =*/ sctx,
    };

    return sampler;
}


@@ -416,6 +416,62 @@ static void test_backend_temp_sampling(const char * model_path) {
}

static void test_backend_min_p_sampling(const char * model_path) {
    test_model_context test_ctx;

    const int   seq_id = 0;
    const float p      = 0.1f;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_min_p(p));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t  batch_idx = test_ctx.idx_for_seq(seq_id);
    float *  logits    = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
    uint32_t n_logits  = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);

    // Collect the logits that survived the min-p threshold (masked entries
    // were biased down to -1e9f).
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
            //printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);

    // Sample with a CPU sampler to verify that the filtered logits are reasonable.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens.
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
        printf("min-p gen step %d: token id:%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        test_ctx.decode_token(token, 0);
    }

    printf("min-p sampling test PASSED\n");
    llama_sampler_free(chain);
}

static void test_backend_multi_sequence_sampling(const char * model_path) {
    test_model_context test_ctx;

@@ -772,6 +828,7 @@ static const backend_test_case BACKEND_TESTS[] = {
    { "set_sampler", test_backend_set_sampler,    true },
    { "max_outputs", test_backend_max_outputs,    true },
    { "mixed",       test_backend_mixed_sampling, true },
    { "min_p",       test_backend_min_p_sampling, true },
};
struct backend_cli_args {