re-write + change parameters + simplify

This commit is contained in:
ddh0 2025-12-13 22:15:03 -06:00
parent 67a733670e
commit a96ddd743a
4 changed files with 130 additions and 211 deletions

View File

@ -164,35 +164,35 @@ enum common_params_sampling_config : uint64_t {
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float power_law_target = -1.0f; // target probability for Power Law sampling (valid range 0.0 to 1.0; <0 = disabled)
int32_t power_law_window_size = 10; // rolling window size for target adaptation in Power Law sampling (≤0 = fixed target)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float power_law_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float power_law_decay = 0.9f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers

View File

@ -1289,24 +1289,28 @@ extern "C" {
const char ** seq_breakers,
size_t num_breakers);
/// @details power-law sampler - reshapes probability distribution to target specific probability ranges
/// power-law
///
/// this sampler implements a power law probability transformation with adaptive
/// target tracking. it reshapes token probability distributions to favor tokens near a
/// configurable target probability, rather than always selecting from the highest probability
/// candidates. it is ideal for creative, unpredictable text generation.
///
/// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
/// rather than just transforming logits. therefore it must always be the last sampler in the
/// sampler chain.
///
/// it is recommended to only perform minimal truncation before this sampler.
/// minimal truncation before this sampler is recommended.
///
/// @param target target probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param window_size rolling window size for target adaptation (≤0 = fixed target)
/// @param seed RNG seed
/// @param target select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0, internally clamped to ≤0.99; ≤0 = no adaptation)
///
/// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation)
/// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
float target,
int32_t window_size,
uint32_t seed);
float target,
float decay,
uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,

View File

@ -2315,133 +2315,62 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
// power-law
//
// this sampler implements a power law probability transformation with adaptive
// target tracking. it reshapes token probability distributions to favor tokens near a
// configurable target probability, rather than always selecting from the highest probability
// candidates. it is ideal for creative, unpredictable text generation.
//
// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
// rather than just transforming logits. therefore it must always be the last sampler in the
// sampler chain.
//
// it is recommended to only perform minimal truncation before this sampler.
// minimal truncation before this sampler is recommended.
//
// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation)
// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
// state of one power-law sampler instance
//
// NOTE(review): this listing shows the removed and the added members of the commit together;
// `window_size` and `window` appear to belong to the pre-change version, `decay`,
// `weighted_sum` and `total_weight` to the post-change version
struct llama_sampler_power_law {
const float target;
const int32_t window_size;
const uint32_t seed;
std::mt19937 rng;
ring_buffer<float> window;
// the desired average probability for selected tokens (0.0 to 1.0)
// higher values favor more probable tokens (more deterministic)
// lower values favor less probable tokens (more creative)
// negative values disable Power Law sampling (sample from distribution as-is)
const float target;
// controls how quickly history influence fades (0.0 to 0.99)
// lower values = faster adaptation, more reactive to recent tokens
// higher values = slower adaptation, more stable over time
// effective history length ≈ 1/(1-decay) tokens
// examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
// internally clamped to <= 0.99 to prevent unbounded accumulation
const float decay;
// seed the RNG was constructed from; clone() passes it back to the init function
const uint32_t seed;
// generator used to draw the selected token from the transformed distribution
std::mt19937 rng;
// historical token probabilities weighted by recency:
// after each selection, weighted_sum = p_selected + decay * weighted_sum,
// where p_selected is the ORIGINAL (pre-transform) probability of the chosen token
float weighted_sum;
// sum of weights, converges to 1/(1-decay):
// after each selection, total_weight = 1 + decay * total_weight
// a value of exactly 0 means "no history yet" and makes the sampler use the base target
float total_weight;
};
// returns the fixed identifier string of the power-law sampler
// the sampler argument is not inspected
static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
    const char * const sampler_name = "power-law";
    return sampler_name;
}
// Computes the target probability for the current sampling step.
//
// The target determines which token probabilities the power law distribution
// will favor. This function implements a dynamic feedback mechanism to maintain
// an average selection probability close to the base target over time.
//
// When the window is empty:
// - Returns the base target value (ctx->target)
//
// When the window has entries:
// - Calculates what the next target should be to keep the weighted average
// of selected token probabilities equal to ctx->target
// - Uses exponential decay weighting: newer values have more influence
//
// Exponential Decay Weighting:
// After inserting the new value, the weights will be:
// new_value: weight = 1 (age 0, newest)
// rat(0): weight = decay (age 1)
// rat(1): weight = decay^2 (age 2)
// ...
// rat(sz-2): weight = decay^(sz-1)
// rat(sz-1): evicted (oldest)
//
// The "effective window size" is approximately 1/(1-decay):
// decay=0.9 → effective window ≈ 10 tokens
// decay=0.95 → effective window ≈ 20 tokens
// decay=1.0 → no decay, equivalent to simple average (original behavior)
//
// Formula derivation:
// We want the weighted average after insertion to equal target:
//
// (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target
//
// Where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1)
// = (1 - decay^sz) / (1 - decay) [geometric series]
//
// Solving for new_value:
// new_value = target * total_weight - decay * Σ rat(i) * decay^i
//
// The factor of 'decay' on the sum accounts for all existing values
// shifting one position older when the new value is inserted.
//
// The exponential decay helps prevent "fishtailing" - a phenomenon where
// forced high-probability selections (when the model is very confident)
// cause the algorithm to overcorrect with many low-probability selections,
// then swing back the other way. By decaying old values, the influence of
// forced selections fades faster, reducing oscillation amplitude and
// recovery time.
//
// Finally, the computed target is clamped to [min_target, max_target] to
// prevent extreme values that could destabilize sampling.
//
// NOTE(review): this span shows the removed window-based implementation and the added
// running-sum implementation of the commit together; the four-parameter signature and the
// `window` loop appear to be the pre-change code
static float llama_sampler_power_law_compute_target(
const llama_sampler_power_law * ctx,
float min_target,
float max_target,
float tail_decay) {
float computed_target = ctx->target;
size_t sz = ctx->window.size();
if (sz > 0) {
// Check if window is at capacity (oldest element will be evicted on next push)
// Use the window_size parameter from context, not a capacity() method
const bool window_full = (sz == (size_t)ctx->window_size);
// Compute weighted sum with exponential decay
// rat(0) = newest in buffer, gets weight 1
// rat(i) gets weight decay^i
//
// When window is full: exclude oldest element (it will be evicted)
// When window is not full: include all elements (nothing evicted)
float weighted_sum = 0.0f;
float weight = 1.0f;
size_t elements_to_sum = window_full ? (sz - 1) : sz;
for (size_t i = 0; i < elements_to_sum; ++i) {
weighted_sum += ctx->window.rat(i) * weight;
weight *= tail_decay;
}
// Shift weights to account for new value taking position 0
// All existing values age by 1, so multiply their weights by decay
float shifted_weighted_sum = weighted_sum * tail_decay;
// Compute total weight after new value is inserted
// When full: sz elements remain (oldest evicted, new added)
// When not full: sz + 1 elements (new added, nothing evicted)
size_t final_element_count = window_full ? sz : (sz + 1);
float total_weight;
if (std::abs(tail_decay - 1.0f) < FLT_EPSILON) {
total_weight = (float) final_element_count;
} else {
total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
}
// Solve for the new value that achieves target weighted average
float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
// Clamp to allowed range
computed_target = std::max(min_target, std::min(next_value, max_target));
// compute the adaptive target probability for the current sampling step
//
// ctx   - sampler state; only target, weighted_sum and total_weight are read, nothing is modified
// decay - history decay factor; the caller passes ctx->decay limited to at most 0.99
//
// returns the probability the next selection should have so that the decay-weighted average
// of selected-token probabilities equals ctx->target, i.e. the solution of
//   (next_value + decay * weighted_sum) / (1 + decay * total_weight) = target
// clamped to [0.0, 1.0]
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, float decay) {
if (ctx->total_weight == 0.0f) {
// if there is no history, just use base target
return ctx->target;
}
return computed_target;
// weight total as it will be once the next selection is added:
// the new entry gets weight 1 and every existing weight is scaled by decay
float new_total_weight = 1.0f + decay * ctx->total_weight;
// solve the weighted-average equation above for the next value
float next_value = ctx->target * new_total_weight - decay * ctx->weighted_sum;
// clamp to [0.0, 1.0]
return std::max(0.0f, std::min(next_value, 1.0f));
}
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@ -2455,30 +2384,25 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
return;
}
// clamp decay to avoid degenerate case at 1.0 (unbounded accumulation)
const float decay = std::min(ctx->decay, 0.99f);
// fixed power law transform parameters
const float distribution_width = 0.3f;
const float peak_logit_value = 5.0f;
const float tail_heaviness = 2.0f;
// target computation parameters
const float min_target = 0.0f;
const float max_target = 1.0f;
const float tail_decay = 0.50f; // exponential decay factor for history weighting
// lower = faster response, higher = more stability
// effective window ≈ 1/(1-decay) ≈ 20 tokens
// compute probabilities to get the "original" values
// get the original probabilities
llama_sampler_softmax_impl(cur_p, false);
// store original probabilities (used for future target adaptation)
// store the original probabilities (needed for history update after selection)
std::vector<float> original_probs;
original_probs.reserve(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
original_probs.push_back(cur_p->data[i].p);
}
// calculate adaptive target
float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
float computed_target = llama_sampler_power_law_compute_target(ctx, decay);
//
// power law transform
@ -2492,40 +2416,30 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
llama_sampler_softmax_impl(cur_p, false);
// sample from the transformed distribution
// sample from transformed distribution
const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx;
// uncomment this to log the target values and history window contents for every token
//
// fprintf(stderr, "power_law: window_size=%zu/%d values=[",
// ctx->window.size(), ctx->window_size);
// for (size_t i = 0; i < ctx->window.size(); ++i) {
// fprintf(stderr, "%.1f", ctx->window.rat(i));
// if (i < ctx->window.size() - 1) fprintf(stderr, ",");
// }
// fprintf(stderr, "] computed_target=%.4f selected_token=%d orig_prob=%.4f\n",
// computed_target, cur_p->data[idx].id, original_probs[idx]);
// fflush(stderr);
// add the ORIGINAL probability to the rolling window
float original_p = original_probs[idx];
ctx->window.push_back(original_p);
// update running history with the original probability of the selected token
float original_p = original_probs[idx];
ctx->weighted_sum = original_p + decay * ctx->weighted_sum;
ctx->total_weight = 1.0f + decay * ctx->total_weight;
}
// clears the adaptation history of the sampler
// with total_weight back at 0 the next target computation returns the base target
// none of the lines shown here touch target, decay, seed or the RNG
//
// NOTE(review): removed and added lines of the commit are shown together; the `window`
// assignment appears to be the pre-change code
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
ctx->window = ring_buffer<float>(ctx->window_size);
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
ctx->weighted_sum = 0.0f;
ctx->total_weight = 0.0f;
}
// creates a copy of the sampler that carries the same configuration, RNG state and history
//
// the copy is built through llama_sampler_init_power_law() from the stored target/decay/seed,
// then its freshly seeded RNG and zeroed history are overwritten with the source's current values
//
// NOTE(review): removed and added lines of the commit are shown together; the `window_size`
// and `window` lines appear to be the pre-change code
static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
auto * result = llama_sampler_init_power_law(ctx->target, ctx->window_size, ctx->seed);
auto * result = llama_sampler_init_power_law(ctx->target, ctx->decay, ctx->seed);
auto * result_ctx = (llama_sampler_power_law *) result->ctx;
result_ctx->rng = ctx->rng;
result_ctx->window = ctx->window;
// copy the mutable state: generator position and running history sums
result_ctx->rng = ctx->rng;
result_ctx->weighted_sum = ctx->weighted_sum;
result_ctx->total_weight = ctx->total_weight;
return result;
}
@ -2545,7 +2459,7 @@ static struct llama_sampler_i llama_sampler_power_law_i = {
struct llama_sampler * llama_sampler_init_power_law(
float target,
int32_t window_size,
float decay,
uint32_t seed
) {
auto seed_cur = get_rng_seed(seed);
@ -2553,10 +2467,11 @@ struct llama_sampler * llama_sampler_init_power_law(
/* .iface = */ &llama_sampler_power_law_i,
/* .ctx = */ new llama_sampler_power_law {
/* .target = */ target,
/* .window_size = */ window_size,
/* .decay = */ decay,
/* .seed = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .window = */ ring_buffer<float>(window_size),
/* .weighted_sum = */ 0.0f,
/* .total_weight = */ 0.0f,
}
);
}

View File

@ -182,33 +182,33 @@ task_params server_task::params_from_json_cmpl(
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target);
params.sampling.power_law_window_size = json_value(data, "power_law_window_size", defaults.sampling.power_law_window_size);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target);
params.sampling.power_law_decay = json_value(data, "power_law_decay", defaults.sampling.power_law_decay);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);