remove old debug log, style nit

This commit is contained in:
ddh0 2025-12-12 23:32:57 -06:00
parent 94cb883ed9
commit 0a19a3fd6c
1 changed files with 16 additions and 12 deletions

View File

@ -2389,10 +2389,12 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
// Finally, the computed target is clamped to [min_target, max_target] to // Finally, the computed target is clamped to [min_target, max_target] to
// prevent extreme values that could destabilize sampling. // prevent extreme values that could destabilize sampling.
// //
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, static float llama_sampler_power_law_compute_target(
const llama_sampler_power_law * ctx,
float min_target, float min_target,
float max_target, float max_target,
float tail_decay) { float tail_decay) {
float computed_target = ctx->target; float computed_target = ctx->target;
size_t sz = ctx->window.size(); size_t sz = ctx->window.size();
@ -2416,6 +2418,10 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la
weight *= tail_decay; weight *= tail_decay;
} }
// Shift weights to account for new value taking position 0
// All existing values age by 1, so multiply their weights by decay
float shifted_weighted_sum = weighted_sum * tail_decay;
// Compute total weight after new value is inserted // Compute total weight after new value is inserted
// When full: sz elements remain (oldest evicted, new added) // When full: sz elements remain (oldest evicted, new added)
// When not full: sz + 1 elements (new added, nothing evicted) // When not full: sz + 1 elements (new added, nothing evicted)
@ -2428,10 +2434,6 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la
total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay); total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
} }
// Shift weights to account for new value taking position 0
// All existing values age by 1, so multiply their weights by decay
float shifted_weighted_sum = weighted_sum * tail_decay;
// Solve for the new value that achieves target weighted average // Solve for the new value that achieves target weighted average
float next_value = (ctx->target * total_weight) - shifted_weighted_sum; float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
@ -2446,7 +2448,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
auto * ctx = (llama_sampler_power_law *) smpl->ctx; auto * ctx = (llama_sampler_power_law *) smpl->ctx;
if (ctx->target < 0.0f) { if (ctx->target < 0.0f) {
fprintf(stderr, "Target below zero, sampling from distribution\n");
// no-op: just sample from the distribution as-is // no-op: just sample from the distribution as-is
llama_sampler_softmax_impl(cur_p, false); llama_sampler_softmax_impl(cur_p, false);
const int idx = llama_sample_dist(cur_p, ctx->rng); const int idx = llama_sample_dist(cur_p, ctx->rng);
@ -2462,9 +2463,9 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
// target computation parameters // target computation parameters
const float min_target = 0.0f; const float min_target = 0.0f;
const float max_target = 1.0f; const float max_target = 1.0f;
const float tail_decay = 0.50f; // Exponential decay factor for history weighting const float tail_decay = 0.50f; // exponential decay factor for history weighting
// Lower = faster response, higher = more stability // lower = faster response, higher = more stability
// Effective window ≈ 1/(1-decay) ≈ 20 tokens // effective window ≈ 1/(1-decay) ≈ 20 tokens
// compute probabilities to get the "original" values // compute probabilities to get the "original" values
llama_sampler_softmax_impl(cur_p, false); llama_sampler_softmax_impl(cur_p, false);
@ -2479,7 +2480,10 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
// calculate adaptive target // calculate adaptive target
float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay); float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
//
// power law transform // power law transform
//
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
float p = cur_p->data[i].p; float p = cur_p->data[i].p;
float normalized_distance = std::abs(p - computed_target) / distribution_width; float normalized_distance = std::abs(p - computed_target) / distribution_width;