copy from author
ref: https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069
This commit is contained in:
parent
53380c183f
commit
94cb883ed9
|
|
@ -2337,21 +2337,134 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
|
|||
return "power-law";
|
||||
}
|
||||
|
||||
// Computes the target probability for the current sampling step.
|
||||
//
|
||||
// The target determines which token probabilities the power law distribution
|
||||
// will favor. This function implements a dynamic feedback mechanism to maintain
|
||||
// an average selection probability close to the base target over time.
|
||||
//
|
||||
// When the window is empty:
|
||||
// - Returns the base target value (ctx->target)
|
||||
//
|
||||
// When the window has entries:
|
||||
// - Calculates what the next target should be to keep the weighted average
|
||||
// of selected token probabilities equal to ctx->target
|
||||
// - Uses exponential decay weighting: newer values have more influence
|
||||
//
|
||||
// Exponential Decay Weighting:
|
||||
// After inserting the new value, the weights will be:
|
||||
// new_value: weight = 1 (age 0, newest)
|
||||
// rat(0): weight = decay (age 1)
|
||||
// rat(1): weight = decay^2 (age 2)
|
||||
// ...
|
||||
// rat(sz-2): weight = decay^(sz-1)
|
||||
// rat(sz-1): evicted (oldest)
|
||||
//
|
||||
// The "effective window size" is approximately 1/(1-decay):
|
||||
// decay=0.9 → effective window ≈ 10 tokens
|
||||
// decay=0.95 → effective window ≈ 20 tokens
|
||||
// decay=1.0 → no decay, equivalent to simple average (original behavior)
|
||||
//
|
||||
// Formula derivation:
|
||||
// We want the weighted average after insertion to equal target:
|
||||
//
|
||||
// (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target
|
||||
//
|
||||
// Where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1)
|
||||
// = (1 - decay^sz) / (1 - decay) [geometric series]
|
||||
//
|
||||
// Solving for new_value:
|
||||
// new_value = target * total_weight - decay * Σ rat(i) * decay^i
|
||||
//
|
||||
// The factor of 'decay' on the sum accounts for all existing values
|
||||
// shifting one position older when the new value is inserted.
|
||||
//
|
||||
// The exponential decay helps prevent "fishtailing" - a phenomenon where
|
||||
// forced high-probability selections (when the model is very confident)
|
||||
// cause the algorithm to overcorrect with many low-probability selections,
|
||||
// then swing back the other way. By decaying old values, the influence of
|
||||
// forced selections fades faster, reducing oscillation amplitude and
|
||||
// recovery time.
|
||||
//
|
||||
// Finally, the computed target is clamped to [min_target, max_target] to
|
||||
// prevent extreme values that could destabilize sampling.
|
||||
//
|
||||
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx,
|
||||
float min_target,
|
||||
float max_target,
|
||||
float tail_decay) {
|
||||
float computed_target = ctx->target;
|
||||
size_t sz = ctx->window.size();
|
||||
|
||||
if (sz > 0) {
|
||||
// Check if window is at capacity (oldest element will be evicted on next push)
|
||||
// Use the window_size parameter from context, not a capacity() method
|
||||
const bool window_full = (sz == ctx->window_size);
|
||||
|
||||
// Compute weighted sum with exponential decay
|
||||
// rat(0) = newest in buffer, gets weight 1
|
||||
// rat(i) gets weight decay^i
|
||||
//
|
||||
// When window is full: exclude oldest element (it will be evicted)
|
||||
// When window is not full: include all elements (nothing evicted)
|
||||
float weighted_sum = 0.0f;
|
||||
float weight = 1.0f;
|
||||
size_t elements_to_sum = window_full ? (sz - 1) : sz;
|
||||
|
||||
for (size_t i = 0; i < elements_to_sum; ++i) {
|
||||
weighted_sum += ctx->window.rat(i) * weight;
|
||||
weight *= tail_decay;
|
||||
}
|
||||
|
||||
// Compute total weight after new value is inserted
|
||||
// When full: sz elements remain (oldest evicted, new added)
|
||||
// When not full: sz + 1 elements (new added, nothing evicted)
|
||||
size_t final_element_count = window_full ? sz : (sz + 1);
|
||||
|
||||
float total_weight;
|
||||
if (std::abs(tail_decay - 1.0f) < FLT_EPSILON) {
|
||||
total_weight = (float) final_element_count;
|
||||
} else {
|
||||
total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
|
||||
}
|
||||
|
||||
// Shift weights to account for new value taking position 0
|
||||
// All existing values age by 1, so multiply their weights by decay
|
||||
float shifted_weighted_sum = weighted_sum * tail_decay;
|
||||
|
||||
// Solve for the new value that achieves target weighted average
|
||||
float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
|
||||
|
||||
// Clamp to allowed range
|
||||
computed_target = std::max(min_target, std::min(next_value, max_target));
|
||||
}
|
||||
|
||||
return computed_target;
|
||||
}
|
||||
|
||||
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
|
||||
|
||||
if (ctx->target < 0.0f) {
|
||||
fprintf(stderr, "Target below zero, sampling from distribution\n");
|
||||
// no-op: just sample from the distribution as-is
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
cur_p->selected = idx;
|
||||
return;
|
||||
}
|
||||
|
||||
// fixed power law transform parameters (from original implementation)
|
||||
const float distribution_width = 0.2f;
|
||||
const float peak_logit_value = 3.0f;
|
||||
const float tail_heaviness = 3.0f;
|
||||
// fixed power law transform parameters
|
||||
const float distribution_width = 0.3f;
|
||||
const float peak_logit_value = 5.0f;
|
||||
const float tail_heaviness = 2.0f;
|
||||
|
||||
// target computation parameters
|
||||
const float min_target = 0.0f;
|
||||
const float max_target = 1.0f;
|
||||
const float tail_decay = 0.50f; // Exponential decay factor for history weighting
|
||||
// Lower = faster response, higher = more stability
|
||||
// Effective window ≈ 1/(1-decay) ≈ 20 tokens
|
||||
|
||||
// compute probabilities to get the "original" values
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
|
|
@ -2363,45 +2476,26 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
original_probs.push_back(cur_p->data[i].p);
|
||||
}
|
||||
|
||||
//
|
||||
// calculate adaptive target
|
||||
//
|
||||
float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
|
||||
|
||||
const float min_target = 0.0f;
|
||||
const float max_target = 1.0f;
|
||||
|
||||
float computed_target = ctx->target;
|
||||
if (ctx->window.size() > 0) {
|
||||
float sum_excluding_oldest = 0.0f;
|
||||
size_t sz = ctx->window.size();
|
||||
|
||||
// sum all except the oldest element
|
||||
for (size_t i = 0; i < sz - 1; ++i) {
|
||||
sum_excluding_oldest += ctx->window.rat(i);
|
||||
}
|
||||
|
||||
float next_value = (ctx->target * ctx->window_size) - sum_excluding_oldest;
|
||||
computed_target = std::max(min_target, std::min(next_value, max_target));
|
||||
}
|
||||
|
||||
//
|
||||
// power law transform
|
||||
//
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float p = cur_p->data[i].p;
|
||||
float p = cur_p->data[i].p;
|
||||
float normalized_distance = std::abs(p - computed_target) / distribution_width;
|
||||
cur_p->data[i].logit = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness));
|
||||
cur_p->data[i].logit = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness));
|
||||
}
|
||||
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
|
||||
// sample from the transformed distribution
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
cur_p->selected = idx;
|
||||
|
||||
// add the ORIGINAL probability to the rolling window
|
||||
ctx->window.push_back(original_probs[idx]);
|
||||
float original_p = original_probs[idx];
|
||||
|
||||
ctx->window.push_back(original_p);
|
||||
}
|
||||
|
||||
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue