remove old debug log, style nit
This commit is contained in:
parent
94cb883ed9
commit
0a19a3fd6c
|
|
@ -2389,10 +2389,12 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
|
|||
// Finally, the computed target is clamped to [min_target, max_target] to
|
||||
// prevent extreme values that could destabilize sampling.
|
||||
//
|
||||
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx,
|
||||
static float llama_sampler_power_law_compute_target(
|
||||
const llama_sampler_power_law * ctx,
|
||||
float min_target,
|
||||
float max_target,
|
||||
float tail_decay) {
|
||||
|
||||
float computed_target = ctx->target;
|
||||
size_t sz = ctx->window.size();
|
||||
|
||||
|
|
@ -2416,6 +2418,10 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la
|
|||
weight *= tail_decay;
|
||||
}
|
||||
|
||||
// Shift weights to account for new value taking position 0
|
||||
// All existing values age by 1, so multiply their weights by decay
|
||||
float shifted_weighted_sum = weighted_sum * tail_decay;
|
||||
|
||||
// Compute total weight after new value is inserted
|
||||
// When full: sz elements remain (oldest evicted, new added)
|
||||
// When not full: sz + 1 elements (new added, nothing evicted)
|
||||
|
|
@ -2428,10 +2434,6 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la
|
|||
total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
|
||||
}
|
||||
|
||||
// Shift weights to account for new value taking position 0
|
||||
// All existing values age by 1, so multiply their weights by decay
|
||||
float shifted_weighted_sum = weighted_sum * tail_decay;
|
||||
|
||||
// Solve for the new value that achieves target weighted average
|
||||
float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
|
||||
|
||||
|
|
@ -2446,7 +2448,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
|
||||
|
||||
if (ctx->target < 0.0f) {
|
||||
fprintf(stderr, "Target below zero, sampling from distribution\n");
|
||||
// no-op: just sample from the distribution as-is
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
|
|
@ -2462,9 +2463,9 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
// target computation parameters
|
||||
const float min_target = 0.0f;
|
||||
const float max_target = 1.0f;
|
||||
const float tail_decay = 0.50f; // Exponential decay factor for history weighting
|
||||
// Lower = faster response, higher = more stability
|
||||
// Effective window ≈ 1/(1-decay) ≈ 20 tokens
|
||||
const float tail_decay = 0.50f; // exponential decay factor for history weighting
|
||||
// lower = faster response, higher = more stability
|
||||
// effective window ≈ 1/(1-decay) ≈ 20 tokens
|
||||
|
||||
// compute probabilities to get the "original" values
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
|
|
@ -2479,7 +2480,10 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
// calculate adaptive target
|
||||
float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
|
||||
|
||||
//
|
||||
// power law transform
|
||||
//
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float p = cur_p->data[i].p;
|
||||
float normalized_distance = std::abs(p - computed_target) / distribution_width;
|
||||
|
|
|
|||
Loading…
Reference in New Issue