From 94cb883ed9184ac96a838566b0cbbb7918237b64 Mon Sep 17 00:00:00 2001
From: ddh0
Date: Fri, 12 Dec 2025 23:19:08 -0600
Subject: [PATCH] copy from author

ref: https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069
---
 src/llama-sampling.cpp | 156 +++++++++++++++++++++++++++++++++--------
 1 file changed, 125 insertions(+), 31 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d5f485f846..738fd05caa 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -2337,21 +2337,134 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
     return "power-law";
 }
 
+// Computes the target probability for the current sampling step.
+//
+// The target determines which token probabilities the power law distribution
+// will favor. This function implements a dynamic feedback mechanism to maintain
+// an average selection probability close to the base target over time.
+//
+// When the window is empty:
+//   - Returns the base target value (ctx->target)
+//
+// When the window has entries:
+//   - Calculates what the next target should be to keep the weighted average
+//     of selected token probabilities equal to ctx->target
+//   - Uses exponential decay weighting: newer values have more influence
+//
+// Exponential Decay Weighting:
+//   After inserting the new value, the weights will be:
+//     new_value:  weight = 1            (age 0, newest)
+//     rat(0):     weight = decay        (age 1)
+//     rat(1):     weight = decay^2      (age 2)
+//     ...
+//     rat(sz-2):  weight = decay^(sz-1)
+//     rat(sz-1):  evicted (oldest)
+//
+//   The "effective window size" is approximately 1/(1-decay):
+//     decay=0.9  → effective window ≈ 10 tokens
+//     decay=0.95 → effective window ≈ 20 tokens
+//     decay=1.0  → no decay, equivalent to simple average (original behavior)
+//
+// Formula derivation (shown for a full window; when the window is not yet
+// full, nothing is evicted and the sums simply run over one more element):
+//   We want the weighted average after insertion to equal target:
+//
+//     (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target
+//
+//   where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1)
+//                      = (1 - decay^sz) / (1 - decay)    [geometric series]
+//
+//   Solving for new_value:
+//     new_value = target * total_weight - decay * Σ rat(i) * decay^i
+//
+//   The factor of 'decay' on the sum accounts for all existing values
+//   shifting one position older when the new value is inserted.
+//
+// The exponential decay helps prevent "fishtailing" - a phenomenon where
+// forced high-probability selections (when the model is very confident)
+// cause the algorithm to overcorrect with many low-probability selections,
+// then swing back the other way. By decaying old values, the influence of
+// forced selections fades faster, reducing oscillation amplitude and
+// recovery time.
+//
+// Finally, the computed target is clamped to [min_target, max_target] to
+// prevent extreme values that could destabilize sampling.
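+//
+// Worked example (illustrative numbers chosen for this comment, not
+// prescribed by the algorithm): target = 0.5, decay = 0.5, and a full
+// window of sz = 2 holding rat(0) = 0.9 (a forced high-probability pick)
+// and rat(1) = 0.2 (the oldest entry, about to be evicted):
+//
+//   weighted_sum (oldest excluded)       = 0.9
+//   total_weight = 1 + 0.5               = 1.5
+//   new_value    = 0.5 * 1.5 - 0.5 * 0.9 = 0.30
+//
+// Check: (0.30 * 1 + 0.9 * 0.5) / 1.5 = 0.5 = target. The next target dips
+// below the base value to pull the weighted average back after the forced
+// 0.9 selection.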
+//
+static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx,
+                                                    float min_target,
+                                                    float max_target,
+                                                    float tail_decay) {
+    float computed_target = ctx->target;
+    size_t sz = ctx->window.size();
+
+    if (sz > 0) {
+        // Check if window is at capacity (oldest element will be evicted on next push).
+        // Use the window_size parameter from context, not a capacity() method.
+        const bool window_full = (sz == ctx->window_size);
+
+        // Compute weighted sum with exponential decay:
+        //   rat(0) = newest in buffer, gets weight 1
+        //   rat(i) gets weight decay^i
+        //
+        // When window is full: exclude oldest element (it will be evicted).
+        // When window is not full: include all elements (nothing evicted).
+        float weighted_sum = 0.0f;
+        float weight = 1.0f;
+        size_t elements_to_sum = window_full ? (sz - 1) : sz;
+
+        for (size_t i = 0; i < elements_to_sum; ++i) {
+            weighted_sum += ctx->window.rat(i) * weight;
+            weight *= tail_decay;
+        }
+
+        // Compute total weight after the new value is inserted.
+        // When full: sz elements remain (oldest evicted, new added).
+        // When not full: sz + 1 elements (new added, nothing evicted).
+        size_t final_element_count = window_full ? sz : (sz + 1);
+
+        float total_weight;
+        if (std::abs(tail_decay - 1.0f) < FLT_EPSILON) {
+            total_weight = (float) final_element_count;
+        } else {
+            total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
+        }
+
+        // Shift weights to account for the new value taking position 0:
+        // all existing values age by 1, so multiply their weights by decay.
+        float shifted_weighted_sum = weighted_sum * tail_decay;
+
+        // Solve for the new value that achieves the target weighted average
+        float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
+
+        // Clamp to the allowed range
+        computed_target = std::max(min_target, std::min(next_value, max_target));
+    }
+
+    return computed_target;
+}
+
 static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_power_law *) smpl->ctx;
 
     if (ctx->target < 0.0f) {
+        fprintf(stderr, "Target below zero, sampling from distribution\n");
         // no-op: just sample from the distribution as-is
         llama_sampler_softmax_impl(cur_p, false);
-        const int idx = llama_sample_dist(cur_p, ctx->rng);
+        const int   idx = llama_sample_dist(cur_p, ctx->rng);
         cur_p->selected = idx;
         return;
     }
 
-    // fixed power law transform parameters (from original implementation)
-    const float distribution_width = 0.2f;
-    const float peak_logit_value   = 3.0f;
-    const float tail_heaviness     = 3.0f;
+    // fixed power law transform parameters
+    const float distribution_width = 0.3f;
+    const float peak_logit_value   = 5.0f;
+    const float tail_heaviness     = 2.0f;
+
+    // target computation parameters
+    const float min_target = 0.0f;
+    const float max_target = 1.0f;
+    const float tail_decay = 0.50f; // Exponential decay factor for history weighting.
+                                    // Lower = faster response, higher = more stability.
+                                    // Effective window ≈ 1/(1-decay) ≈ 2 tokens.
 
     // compute probabilities to get the "original" values
     llama_sampler_softmax_impl(cur_p, false);
@@ -2363,45 +2476,26 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
         original_probs.push_back(cur_p->data[i].p);
     }
 
-    //
-    // calculate adaptive target
-    //
+    float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
 
-    const float min_target = 0.0f;
-    const float max_target = 1.0f;
-
-    float computed_target = ctx->target;
-    if (ctx->window.size() > 0) {
-        float sum_excluding_oldest = 0.0f;
-        size_t sz = ctx->window.size();
-
-        // sum all except the oldest element
-        for (size_t i = 0; i < sz - 1; ++i) {
-            sum_excluding_oldest += ctx->window.rat(i);
-        }
-
-        float next_value = (ctx->target * ctx->window_size) - sum_excluding_oldest;
-        computed_target = std::max(min_target, std::min(next_value, max_target));
-    }
-
-    //
-    // power law transform
-    //
-
     for (size_t i = 0; i < cur_p->size; ++i) {
-        float p = cur_p->data[i].p;
+        float p                   = cur_p->data[i].p;
         float normalized_distance = std::abs(p - computed_target) / distribution_width;
-        cur_p->data[i].logit = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness));
+        cur_p->data[i].logit      = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness));
     }
 
     llama_sampler_softmax_impl(cur_p, false);
 
     // sample from the transformed distribution
-    const int idx = llama_sample_dist(cur_p, ctx->rng);
+    const int   idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;
 
     // add the ORIGINAL probability to the rolling window
-    ctx->window.push_back(original_probs[idx]);
+    float original_p = original_probs[idx];
+
+    ctx->window.push_back(original_p);
 }
 
 static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
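
For intuition about the transform in the second hunk: each candidate token's new
logit is a Cauchy-like bump centered on the computed target, so tokens whose
ORIGINAL probability lies closest to the target are favored after the final
softmax. A minimal standalone sketch using the patch's fixed constants (the
computed_target of 0.5 and the toy probabilities are assumptions for
illustration, not values taken from the patch):

    #include <cmath>
    #include <cstdio>

    int main() {
        // fixed power law transform parameters from the patch
        const float distribution_width = 0.3f;
        const float peak_logit_value   = 5.0f;
        const float tail_heaviness     = 2.0f;

        // assumed adaptive target for this illustration
        const float computed_target = 0.5f;

        // toy post-softmax probabilities for five candidate tokens
        const float probs[] = { 0.70f, 0.50f, 0.15f, 0.10f, 0.05f };

        for (float p : probs) {
            float d     = std::abs(p - computed_target) / distribution_width;
            float logit = peak_logit_value / (1.0f + std::pow(d, tail_heaviness));
            std::printf("p = %.2f -> new logit = %.3f\n", p, logit);
        }
        return 0;
    }

The 0.50 candidate receives the peak logit (5.0); both the over-confident 0.70
and the low-probability tail are pushed down, which is how the sampler steers
the average selected probability toward the target.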