llama : add adaptive-p sampler (#17927)

* initial commit for branch

* simplify constants

* add params to `struct common_params_sampling`, add reference to PR

* explicitly clamp `min_target` and `max_target` to `[0.0, 1.0]`

* add args, rename `queue_size` -> `window_size`

* improved comments

* minor

* remove old unused code from algorithm

* minor

* add power law case to `common_sampler_init`, add sampler name mappings

* clarify behaviour when `window_size = 0`

* add missing enums

* remove `target_range` param, make `target == 1` no-op, cleanup code

* oops, straggler

* add missing parameters in `server-task.cpp`

* copy from author

ref:
https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069

* remove old debug log, style nit

* fix compiler warning, add commented-out logging per token

* re-write + change parameters + simplify

* oops forgot args.cpp

* fix leftover `window_size`

* add missing values to `common_params_sampling::print()`

* with logging

* does this fix it?

* no, but does this?

* update default decay

* optimize

* fix bad merge

my git skills are lacking

* silence `missing initializer for member`

* update default decay to 0.9

* fix logging

* format (double)

* add power law to the new `samplers` vector

* log sampler init values

* improve logging messages in llama_sampler_power_law

* remove extraneous logging

* simplify target computation

last commit with debug logging!

* remove debug logging, explicitly clamp params at init

* add `use_power_law` flag + logic, minor cleanup

* update `power-law` -> `adaptive-p`

* fix cold start EMA

- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f - clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f - clamped_decay)`

this fixes a "cold start" problem with the moving average

* update `SHARPNESS` constant to `10.0f`

* minor style fixes

no functional changes

* minor style fixes cont.

* update `llama_sampler_adaptive_p_i` for backend sampling (ref: #17004)

* separate into `apply` + `accept` functions

* `pending_token_idx`: switch from `llama_token` to `int32`

functionally identical (`llama.h` has `typedef int32_t llama_token;`),
but it's more correct now

* don't transform logits <= -1e9f

* fix masking in backend top-p, min-p

* address review comments

* typo in comments `RND` -> `RNG`

* add docs

* add recommended values in completion docs

* address PR feedback

* remove trailing whitespace (for CI `editorconfig`)

* add adaptive-p to `common_sampler_types_from_chars`

ddh0 2026-01-15 11:16:29 -06:00 committed by GitHub
parent a04c2b06a3
commit 13f1e4a9ca
9 changed files with 297 additions and 52 deletions

View File

@@ -1729,6 +1729,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_sparam());
add_opt(common_arg(
{"--adaptive-target"}, "N",
string_format("adaptive-p: select tokens near this probability (valid range 0.0 "
"to 1.0; negative = disabled) (default: %.2f)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/17927)",
(double)params.sampling.adaptive_target),
[](common_params & params, const std::string & value) {
params.sampling.adaptive_target = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--adaptive-decay"}, "N",
string_format("adaptive-p: decay rate for target adaptation over time. lower values "
"are more reactive, higher values are more stable.\n"
"(valid range 0.0 to 0.99) (default: %.2f)",
(double)params.sampling.adaptive_decay),
[](common_params & params, const std::string & value) {
params.sampling.adaptive_decay = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),

View File

@@ -119,6 +119,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12,
};
// dimensionality reduction methods, used by cvector-generator
@@ -166,32 +167,34 @@ enum common_params_sampling_config : uint64_t {
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers

View File

@@ -167,11 +167,11 @@ std::string common_params_sampling::print() const {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau);
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
return std::string(result);
}
@@ -255,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
if (params.mirostat == 0) {
bool use_adaptive_p = false; // see below
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
@@ -264,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
samplers.push_back(llama_sampler_init_top_k (params.top_k));
samplers.push_back(llama_sampler_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
samplers.push_back(llama_sampler_init_infill (vocab));
samplers.push_back(llama_sampler_init_infill(vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
// the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
// a single token, so we will add `dist` at the end of the chain by default,
// unless the user specifically included `adaptive-p`. we set this flag here
// so we know to add the sampler at the very end.
use_adaptive_p = true;
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
samplers.push_back(llama_sampler_init_dist(params.seed));
if (use_adaptive_p) {
// only if user explicitly included adaptive-p sampler
samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
} else {
// default: sample from distribution
samplers.push_back(llama_sampler_init_dist(params.seed));
}
} else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@@ -625,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a';
default : return '?';
}
}
@@ -641,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p";
default : return "";
}
}
@@ -657,6 +673,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
// since samplers names are written multiple ways
@@ -672,6 +689,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> samplers;
@@ -708,6 +726,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> samplers;

View File

@@ -1395,6 +1395,33 @@ extern "C" {
const char ** seq_breakers,
size_t num_breakers);
/// adaptive-p: select tokens near a configurable target probability over time.
///
/// the adaptive-p sampler transforms the token probability distribution to favor tokens
/// that fall near a user-configurable probability target.
///
/// internally, the sampler maintains an exponential moving average of the *ORIGINAL*
/// probabilities of selected tokens at each sampling step. it uses this EMA to compute an
/// adapted target probability at each sampling step, thus maintaining the desired target
/// probability over time.
///
/// adaptive-p selects a token ID rather than just mutating candidates, so it must be last
/// in the sampler chain (like mirostat, dist, greedy).
///
/// only mild truncation before this sampler is recommended. we suggest applying min-p
/// before adaptive-p as the only other active sampler in the chain.
///
/// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
/// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
/// @param seed RNG seed
///
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927
///
LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p(
float target,
float decay,
uint32_t seed);
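// illustration: a minimal chain following the recommendation above (a sketch only,
// using the existing sampler-chain API from this header; the 0.55 / 0.90 values are
// the recommended starting points from the docs):
//
//     llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//     llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, 1));                              // mild truncation only
//     llama_sampler_chain_add(chain, llama_sampler_init_adaptive_p(0.55f, 0.90f, LLAMA_DEFAULT_SEED)); // must be last in the chain
//     llama_token tok = llama_sampler_sample(chain, ctx, -1);  // ctx: your llama_context *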
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,

View File

@@ -1513,12 +1513,9 @@ static void llama_sampler_top_p_backend_apply(
mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
// top_p_bias = (mask * 1e9f) - 1e9f.
// So entries in the mask that we want to discard will become -1e9f, and
// others will be 0 (meaning that will not effect the logits).
const float large_val = 1e9f;
struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
// Apply -INFINITY bias for masked-out tokens
// log(1) = 0 (keep), log(0) = -INF (discard)
struct ggml_tensor * top_p_bias = ggml_log(ctx, mask);
ggml_set_name(top_p_bias, "top_p_bias");
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
@@ -1673,15 +1670,11 @@ static void llama_sampler_min_p_backend_apply(
struct ggml_tensor * mask = ggml_step(ctx, sub);
ggml_set_name(mask, "min_p_mask");
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
// min_p_bias = (mask * 1e9f) - 1e9f.
// So entries in the mask that we want to discard will become -1e9f, and
// others will be 0 (meaning that will not effect the logits).
const float large_val = 1e9f;
struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
// Apply -INFINITY bias for masked-out tokens
// log(1) = 0 (keep), log(0) = -INF (discard)
struct ggml_tensor * min_p_bias = ggml_log(ctx, mask);
ggml_set_name(min_p_bias, "min_p_bias");
// Add the min_p bias to the logits.
data->logits = ggml_add(ctx, data->logits, min_p_bias);
ggml_set_name(data->logits, "min_p_logits");
@@ -3293,6 +3286,170 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
return result;
}
// adaptive-p sampler state
//
// maintains an exponential moving average of the *ORIGINAL* probabilities
// of selected tokens, used to compute an adapted target at each sampling step.
//
// see llama.h for a full description of the sampler
//
// ref: https://github.com/ggml-org/llama.cpp/pull/17927
//
struct llama_sampler_adaptive_p {
const float target; // target probability (0.0 - 1.0; negative = disabled)
const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
const uint32_t seed; // original RNG seed
uint32_t seed_cur; // actual RNG seed
std::mt19937 rng; // RNG state
float weighted_sum; // sum(p_i * decay^i)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
std::vector<float> original_probs; // pre-transform probs, cached for EMA update
llama_token pending_token_id; // token ID of selected token
int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs
};
// adaptive probability transformation constants
static constexpr float DISTRIBUTION_WIDTH = 0.3f;
static constexpr float PEAK_LOGIT_VALUE = 5.0f;
static constexpr float SHARPNESS = 10.0f;
static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;
static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
return "adaptive-p";
}
static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
llama_sampler_softmax_impl(cur_p, false);
if (ctx->target < 0.0f) {
// at negative target values, adaptive-p is no-op
// we simply sample from the existing distribution
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
return;
}
// store the original probabilities
ctx->original_probs.resize(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
ctx->original_probs[i] = cur_p->data[i].p;
}
// using the EMA, compute the adapted target probability for the current sampling step
auto target = std::clamp(ctx->target, 0.0f, 1.0f);
float adapted_target = std::clamp(
ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
0.0f, 1.0f
);
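// illustrative example: with target = 0.55, an EMA of selected-token probabilities of
// 0.60 ("running hot") gives adapted_target = 2*0.55 - 0.60 = 0.50, steering the next
// pick toward lower-probability tokens; an EMA of 0.50 raises it to 0.60. the clamp
// keeps the adapted target within [0.0, 1.0].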
// adaptive probability transform
//
// quadratic near target for fine differentiation, transitioning to linear decay in the
// tails. unbounded negative logits ensure proper suppression of far-from-target tokens
// after the softmax.
//
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit == -INFINITY) {
// don't transform logits that are -INFINITY
// (as masked out by e.g. min-p and top-p when using backend sampling)
continue;
}
float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
}
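// worked example of the transform (DISTRIBUTION_WIDTH = 0.3, PEAK_LOGIT_VALUE = 5.0, SHARPNESS = 10.0):
//   p == adapted_target      -> dist = 0 -> logit =  5.0 (peak)
//   p off by 0.3 (one width) -> dist = 1 -> logit =  5.0 - 10*1/2 =  0.0
//   p off by 0.9             -> dist = 3 -> logit =  5.0 - 10*9/4 = -17.5
// tokens near the adapted target therefore dominate the softmax below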
// softmax and sample from the transformed distribution
llama_sampler_softmax_impl(cur_p, false);
const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx;
// store the selected token ID for acceptance later
ctx->pending_token_id = cur_p->data[idx].id;
ctx->pending_token_idx = idx;
}
static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
if (ctx->pending_token_id == token) {
GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL);
GGML_ASSERT(ctx->pending_token_idx != -1);
// update EMA with the original probability of the selected token
ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum;
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
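// these recurrences keep weighted_sum = sum(p_i * decay^i) and total_weight = sum(decay^i)
// over accepted tokens, so weighted_sum / total_weight is the normalized EMA that
// _apply uses to compute the adapted target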
}
ctx->pending_token_id = LLAMA_TOKEN_NULL;
ctx->pending_token_idx = -1;
}
static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
// ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
// original_probs is completely overwritten on every call to _apply.
// so we only need to reset the EMA state and pending token.
ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
ctx->total_weight = 1.0f / (1.0f - ctx->decay);
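// warm start: with these values weighted_sum / total_weight == target from the first
// step (e.g. target = 0.55, decay = 0.9 -> 5.5 / 10.0 = 0.55), so early adapted targets
// are not skewed by the first few accepted tokens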
ctx->pending_token_id = LLAMA_TOKEN_NULL;
ctx->pending_token_idx = -1;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}
static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx;
auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
// copy everything (target, decay, seed, and RNG are already set)
result_ctx->weighted_sum = ctx->weighted_sum;
result_ctx->total_weight = ctx->total_weight;
result_ctx->pending_token_id = ctx->pending_token_id;
result_ctx->pending_token_idx = ctx->pending_token_idx;
return result;
}
static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_adaptive_p *) smpl->ctx;
}
static struct llama_sampler_i llama_sampler_adaptive_p_i = {
/* .name = */ llama_sampler_adaptive_p_name,
/* .accept = */ llama_sampler_adaptive_p_accept,
/* .apply = */ llama_sampler_adaptive_p_apply,
/* .reset = */ llama_sampler_adaptive_p_reset,
/* .clone = */ llama_sampler_adaptive_p_clone,
/* .free = */ llama_sampler_adaptive_p_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};
struct llama_sampler * llama_sampler_init_adaptive_p(
float target,
float decay,
uint32_t seed
) {
auto seed_cur = get_rng_seed(seed);
float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
return llama_sampler_init(
/* .iface = */ &llama_sampler_adaptive_p_i,
/* .ctx = */ new llama_sampler_adaptive_p {
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .original_probs = */ {},
/* .pending_token_id = */ LLAMA_TOKEN_NULL,
/* .pending_token_idx = */ -1
}
);
}
// logit-bias
struct llama_sampler_logit_bias : public llama_sampler_backend {

View File

@@ -113,6 +113,8 @@
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) |
| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) |
| `--adaptive-decay N` | adaptive-p: EMA decay for adaptation; effective history length ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) |
| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) |
| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) |
| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) |

View File

@@ -436,6 +436,19 @@ The Min-P sampling method was designed as an alternative to Top-P, and aims to e
Example usage: `--min-p 0.05`
### Adaptive-P Sampling
- `--adaptive-target N`: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
- `--adaptive-decay N`: EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
Adaptive-P: Select tokens near a configurable target probability over time.
The adaptive-p sampler transforms the token probability distribution to favor tokens that fall near a user-configurable probability target. Internally, it maintains an exponential moving average (EMA) of the *ORIGINAL* probabilities of the selected tokens, and uses this EMA to compute an adapted target at each sampling step, maintaining the desired target probability over time. Only mild truncation before this sampler is recommended; it is suggested to apply min-p before adaptive-p as the only other active sampler in the chain.
Recommended starting values: `--adaptive-target 0.55 --adaptive-decay 0.9`
For more info, refer to: [llama.cpp#17927](https://github.com/ggml-org/llama.cpp/pull/17927)
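Example usage, pairing mild min-p truncation with adaptive-p as a suggested starting chain: `--min-p 0.05 --adaptive-target 0.55 --adaptive-decay 0.9`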
### Locally Typical Sampling
- `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled).

View File

@@ -130,6 +130,8 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) |
| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) |
| `--adaptive-decay N` | adaptive-p: EMA decay for adaptation; effective history length ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) |
| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) |
| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) |
| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) |

View File

@@ -204,6 +204,8 @@ task_params server_task::params_from_json_cmpl(
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target);
params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);