diff --git a/common/common.h b/common/common.h index ba3d776bdc..66a6ca96b3 100644 --- a/common/common.h +++ b/common/common.h @@ -164,35 +164,35 @@ enum common_params_sampling_config : uint64_t { struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float xtc_probability = 0.00f; // 0.0 = disabled - float xtc_threshold = 0.10f; // > 0.5 disables XTC - float typ_p = 1.00f; // typical_p, 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: - float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) - int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty - int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - float power_law_target = -1.0f; // target probability for Power Law sampling (valid range 0.0 to 1.0; <0 = disabled) - int32_t power_law_window_size = 10; // rolling window size for target adaptation in Power Law sampling (≤0 = fixed target) - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float top_n_sigma = -1.00f; // -1.0 = disabled - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool ignore_eos = false; - bool no_perf = false; // disable performance metrics - bool timing_per_token = false; + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + float power_law_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) + float power_law_decay = 0.9f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f; // -1.0 = disabled + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool ignore_eos = false; + bool no_perf = false; // disable performance metrics + bool timing_per_token = false; uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers diff --git a/include/llama.h b/include/llama.h index ce1308d2bd..f3867c6988 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1289,24 +1289,28 @@ extern "C" { const char ** seq_breakers, size_t num_breakers); - /// @details power-law sampler - reshapes probability distribution to target specific probability ranges + /// power-law + /// + /// this sampler implements a power law probability transformation with adaptive + /// target tracking. it reshapes token probability distributions to favor tokens near a + /// configurable target probability, rather than always selecting from the highest probability + /// candidates. it is ideal for creative, unpredictable text generation. /// /// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID /// rather than just transforming logits. therefore it must always be the last sampler in the /// sampler chain. /// - /// it is recommended to only perform minimal truncation before this sampler. + /// minimal truncation before this sampler is recommended. /// - /// @param target target probability (valid range 0.0 to 1.0; <0 = disabled) - /// @param window_size rolling window size for target adaptation (≤0 = fixed target) - /// @param seed RNG seed + /// @param target select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) + /// @param decay decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) /// - /// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation) + /// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl) /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR) LLAMA_API struct llama_sampler * llama_sampler_init_power_law( - float target, - int32_t window_size, - uint32_t seed); + float target, + float decay, + uint32_t seed); LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7686f59148..db126a18d5 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -2315,133 +2315,62 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa // power-law // +// this sampler implements a power law probability transformation with adaptive +// target tracking. it reshapes token probability distributions to favor tokens near a +// configurable target probability, rather than always selecting from the highest probability +// candidates. it is ideal for creative, unpredictable text generation. +// // this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID // rather than just transforming logits. therefore it must always be the last sampler in the // sampler chain. // -// it is recommended to only perform minimal truncation before this sampler. +// minimal truncation before this sampler is recommended. // -// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation) +// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl) // ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR) struct llama_sampler_power_law { - const float target; - const int32_t window_size; - const uint32_t seed; - std::mt19937 rng; - ring_buffer window; + // the desired average probability for selected tokens (0.0 to 1.0) + // higher values favor more probable tokens (more deterministic) + // lower values favor less probable tokens (more creative) + // negative values disable Power Law sampling (sample from distribution as-is) + const float target; + + // controls how quickly history influence fades (0.0 to 0.99) + // lower values = faster adaptation, more reactive to recent tokens + // higher values = slower adaptation, more stable over time + // effective history length ≈ 1/(1-decay) tokens + // examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20 + // internally clamped to <= 0.99 to prevent unbounded accumulation + const float decay; + + const uint32_t seed; + std::mt19937 rng; + + // historical token probabilities weighted by recency + float weighted_sum; + // sum of weights, converges to 1/(1-decay) + float total_weight; }; static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) { return "power-law"; } -// Computes the target probability for the current sampling step. -// -// The target determines which token probabilities the power law distribution -// will favor. This function implements a dynamic feedback mechanism to maintain -// an average selection probability close to the base target over time. -// -// When the window is empty: -// - Returns the base target value (ctx->target) -// -// When the window has entries: -// - Calculates what the next target should be to keep the weighted average -// of selected token probabilities equal to ctx->target -// - Uses exponential decay weighting: newer values have more influence -// -// Exponential Decay Weighting: -// After inserting the new value, the weights will be: -// new_value: weight = 1 (age 0, newest) -// rat(0): weight = decay (age 1) -// rat(1): weight = decay^2 (age 2) -// ... -// rat(sz-2): weight = decay^(sz-1) -// rat(sz-1): evicted (oldest) -// -// The "effective window size" is approximately 1/(1-decay): -// decay=0.9 → effective window ≈ 10 tokens -// decay=0.95 → effective window ≈ 20 tokens -// decay=1.0 → no decay, equivalent to simple average (original behavior) -// -// Formula derivation: -// We want the weighted average after insertion to equal target: -// -// (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target -// -// Where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1) -// = (1 - decay^sz) / (1 - decay) [geometric series] -// -// Solving for new_value: -// new_value = target * total_weight - decay * Σ rat(i) * decay^i -// -// The factor of 'decay' on the sum accounts for all existing values -// shifting one position older when the new value is inserted. -// -// The exponential decay helps prevent "fishtailing" - a phenomenon where -// forced high-probability selections (when the model is very confident) -// cause the algorithm to overcorrect with many low-probability selections, -// then swing back the other way. By decaying old values, the influence of -// forced selections fades faster, reducing oscillation amplitude and -// recovery time. -// -// Finally, the computed target is clamped to [min_target, max_target] to -// prevent extreme values that could destabilize sampling. -// -static float llama_sampler_power_law_compute_target( - const llama_sampler_power_law * ctx, - float min_target, - float max_target, - float tail_decay) { - - float computed_target = ctx->target; - size_t sz = ctx->window.size(); - - if (sz > 0) { - // Check if window is at capacity (oldest element will be evicted on next push) - // Use the window_size parameter from context, not a capacity() method - const bool window_full = (sz == (size_t)ctx->window_size); - - // Compute weighted sum with exponential decay - // rat(0) = newest in buffer, gets weight 1 - // rat(i) gets weight decay^i - // - // When window is full: exclude oldest element (it will be evicted) - // When window is not full: include all elements (nothing evicted) - float weighted_sum = 0.0f; - float weight = 1.0f; - size_t elements_to_sum = window_full ? (sz - 1) : sz; - - for (size_t i = 0; i < elements_to_sum; ++i) { - weighted_sum += ctx->window.rat(i) * weight; - weight *= tail_decay; - } - - // Shift weights to account for new value taking position 0 - // All existing values age by 1, so multiply their weights by decay - float shifted_weighted_sum = weighted_sum * tail_decay; - - // Compute total weight after new value is inserted - // When full: sz elements remain (oldest evicted, new added) - // When not full: sz + 1 elements (new added, nothing evicted) - size_t final_element_count = window_full ? sz : (sz + 1); - - float total_weight; - if (std::abs(tail_decay - 1.0f) < FLT_EPSILON) { - total_weight = (float) final_element_count; - } else { - total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay); - } - - // Solve for the new value that achieves target weighted average - float next_value = (ctx->target * total_weight) - shifted_weighted_sum; - - // Clamp to allowed range - computed_target = std::max(min_target, std::min(next_value, max_target)); +// compute the adaptive target probability for the current sampling step +static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, float decay) { + if (ctx->total_weight == 0.0f) { + // if there is no history, just use base target + return ctx->target; } - return computed_target; + // maintain a running weighted sum with exponential decay + float new_total_weight = 1.0f + decay * ctx->total_weight; + float next_value = ctx->target * new_total_weight - decay * ctx->weighted_sum; + + // clamp to [0.0, 1.0] + return std::max(0.0f, std::min(next_value, 1.0f)); } static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -2455,30 +2384,25 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok return; } + // clamp decay to avoid degenerate case at 1.0 (unbounded accumulation) + const float decay = std::min(ctx->decay, 0.99f); + // fixed power law transform parameters const float distribution_width = 0.3f; const float peak_logit_value = 5.0f; const float tail_heaviness = 2.0f; - // target computation parameters - const float min_target = 0.0f; - const float max_target = 1.0f; - const float tail_decay = 0.50f; // exponential decay factor for history weighting - // lower = faster response, higher = more stability - // effective window ≈ 1/(1-decay) ≈ 20 tokens - - // compute probabilities to get the "original" values + // get the original probabilities llama_sampler_softmax_impl(cur_p, false); - // store original probabilities (used for future target adaptation) + // store the original probabilities (needed for history update after selection) std::vector original_probs; original_probs.reserve(cur_p->size); for (size_t i = 0; i < cur_p->size; ++i) { original_probs.push_back(cur_p->data[i].p); } - // calculate adaptive target - float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay); + float computed_target = llama_sampler_power_law_compute_target(ctx, decay); // // power law transform @@ -2492,40 +2416,30 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok llama_sampler_softmax_impl(cur_p, false); - // sample from the transformed distribution + // sample from transformed distribution const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; - // uncomment this to log the target values and history window contents for every token - // - // fprintf(stderr, "power_law: window_size=%zu/%d values=[", - // ctx->window.size(), ctx->window_size); - // for (size_t i = 0; i < ctx->window.size(); ++i) { - // fprintf(stderr, "%.1f", ctx->window.rat(i)); - // if (i < ctx->window.size() - 1) fprintf(stderr, ","); - // } - // fprintf(stderr, "] computed_target=%.4f selected_token=%d orig_prob=%.4f\n", - // computed_target, cur_p->data[idx].id, original_probs[idx]); - // fflush(stderr); - - // add the ORIGINAL probability to the rolling window - float original_p = original_probs[idx]; - - ctx->window.push_back(original_p); + // update running history with the original probability of the selected token + float original_p = original_probs[idx]; + ctx->weighted_sum = original_p + decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + decay * ctx->total_weight; } static void llama_sampler_power_law_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_power_law *) smpl->ctx; - ctx->window = ring_buffer(ctx->window_size); + auto * ctx = (llama_sampler_power_law *) smpl->ctx; + ctx->weighted_sum = 0.0f; + ctx->total_weight = 0.0f; } static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_power_law *) smpl->ctx; - auto * result = llama_sampler_init_power_law(ctx->target, ctx->window_size, ctx->seed); + auto * result = llama_sampler_init_power_law(ctx->target, ctx->decay, ctx->seed); auto * result_ctx = (llama_sampler_power_law *) result->ctx; - result_ctx->rng = ctx->rng; - result_ctx->window = ctx->window; + result_ctx->rng = ctx->rng; + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; return result; } @@ -2545,7 +2459,7 @@ static struct llama_sampler_i llama_sampler_power_law_i = { struct llama_sampler * llama_sampler_init_power_law( float target, - int32_t window_size, + float decay, uint32_t seed ) { auto seed_cur = get_rng_seed(seed); @@ -2553,10 +2467,11 @@ struct llama_sampler * llama_sampler_init_power_law( /* .iface = */ &llama_sampler_power_law_i, /* .ctx = */ new llama_sampler_power_law { /* .target = */ target, - /* .window_size = */ window_size, + /* .decay = */ decay, /* .seed = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), - /* .window = */ ring_buffer(window_size), + /* .weighted_sum = */ 0.0f, + /* .total_weight = */ 0.0f, } ); } diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index c3ac98f13f..6c083e6624 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -182,33 +182,33 @@ task_params server_task::params_from_json_cmpl( params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target); - params.sampling.power_law_window_size = json_value(data, "power_law_window_size", defaults.sampling.power_law_window_size); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target); + params.sampling.power_law_decay = json_value(data, "power_law_decay", defaults.sampling.power_law_decay); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);