fix EMA cold start in the adaptive-p sampler

- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f - clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f - clamped_decay)`

This fixes a "cold start" problem with the exponential moving average: the EMA state now starts at its steady-state values, as if the history were already saturated at `target`, so the first few sampled tokens no longer dominate the running estimate.
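
Why these values work: the sampler's running estimate of the average selected-token probability is `weighted_sum / total_weight`. `total_weight` accumulates the geometric series `1 + decay + decay^2 + ...`, which converges to `1 / (1 - decay)`, and if every past token had probability exactly `target`, `weighted_sum` would converge to `target / (1 - decay)`. Seeding both accumulators at those limits makes the estimate read exactly `target` from the first token onward. A minimal standalone sketch (plain C++ with arbitrary values, not llama.cpp code) comparing the two initializations:

```cpp
#include <cstdio>

int main() {
    const float target = 0.5f;
    const float decay  = 0.9f;

    // old init: empty history
    float ws_old = 0.0f, tw_old = 0.0f;

    // new init: history pre-saturated at the steady state, as if every
    // past token had landed exactly on `target`
    float ws_new = target / (1.0f - decay); // = 5.0
    float tw_new = 1.0f   / (1.0f - decay); // = 10.0

    // feed a few unusually high probabilities and watch the estimates
    const float probs[] = { 0.95f, 0.90f, 0.92f };
    for (float p : probs) {
        ws_old = p + decay * ws_old;  tw_old = 1.0f + decay * tw_old;
        ws_new = p + decay * ws_new;  tw_new = 1.0f + decay * tw_new;
        printf("old EMA = %.3f   new EMA = %.3f\n",
               ws_old / tw_old, ws_new / tw_new);
    }
    return 0;
}
```

After one token the old estimate has already jumped to the observed 0.950, while the seeded estimate has only moved from 0.500 to 0.545.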
Author: ddh0
Date:   2025-12-28 20:31:26 -06:00
Parent: f0d3f13124
Commit: e7a892065d

1 changed file with 17 additions and 12 deletions

@@ -2338,7 +2338,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 // ref: https://github.com/ggml-org/llama.cpp/pull/17927
 struct llama_sampler_adaptive_p {
     const float target;   // target probability (0.0 - 1.0; negative = disabled)
-    const float decay;    // EMA decay; history 1/(1-decay) tokens (0.0 - 0.99)
+    const float decay;    // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
     const uint32_t seed;  // RNG seed
     std::mt19937 rng;     // RNG
     float weighted_sum;   // sum(p_i * decay^i)
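
The comment fix matters for intuition: since the weights `decay^i` sum to `1/(1-decay)`, the EMA behaves like a sliding window of roughly that many tokens, but only approximately (hence the `~=`). For example, `decay = 0.9` gives an effective history of about 10 tokens, and the maximum `decay = 0.99` about 100.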
@@ -2397,15 +2397,18 @@ static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_to
     const int idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;
 
-    // update history with the original probability of the selected token
+    // update EMA with the original probability of the selected token
     ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
     ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
 }
 
 static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
-    ctx->weighted_sum = 0.0f;
-    ctx->total_weight = 0.0f;
+    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
+    // ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
+    // original_probs is completely overwritten on every call to _apply,
+    // so we only need to reset the EMA state.
+    ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
+    ctx->total_weight = 1.0f / (1.0f - ctx->decay);
 }
 
 static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
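
Worked numbers for the new reset (arbitrary values): with `target = 0.5` and `decay = 0.9`, reset sets `weighted_sum = 0.5 / 0.1 = 5.0` and `total_weight = 1.0 / 0.1 = 10.0`, so the estimate `weighted_sum / total_weight` starts at exactly `0.5` over a fully saturated history. Under the old zero initialization, the estimate right after a reset was `0.0 / 0.0`.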
@@ -2413,10 +2416,11 @@ static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_
     auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
     auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
 
-    result_ctx->rng = ctx->rng;
-    result_ctx->weighted_sum = ctx->weighted_sum;
-    result_ctx->total_weight = ctx->total_weight;
-    result_ctx->original_probs.reserve(ctx->original_probs.capacity());
+    // copy everything (target, decay, and seed are already set)
+    result_ctx->original_probs = ctx->original_probs;
+    result_ctx->weighted_sum = ctx->weighted_sum;
+    result_ctx->total_weight = ctx->total_weight;
+    result_ctx->rng = ctx->rng;
 
     return result;
 }
@@ -2440,15 +2444,16 @@ struct llama_sampler * llama_sampler_init_adaptive_p(
         uint32_t seed
 ) {
     auto seed_cur = get_rng_seed(seed);
+    float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_adaptive_p_i,
         /* .ctx   = */ new llama_sampler_adaptive_p {
             /* .target = */ target,
-            /* .decay  = */ std::clamp(decay, 0.0f, 0.99f),
+            /* .decay  = */ clamped_decay,
             /* .seed   = */ seed_cur,
             /* .rng    = */ std::mt19937(seed_cur),
-            /* .weighted_sum = */ 0.0f,
-            /* .total_weight = */ 0.0f,
+            /* .weighted_sum = */ target / (1.0f - clamped_decay),
+            /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
             /* .original_probs = */ {},
         }
     );
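
For context, a hedged sketch of how the sampler might be wired into a sampler chain. The chain helpers (`llama_sampler_chain_init`, `llama_sampler_chain_add`, `llama_sampler_sample`, `llama_sampler_free`) are llama.cpp's existing sampling API; the public declaration of `llama_sampler_init_adaptive_p` is assumed to ship with PR #17927, and the parameter values are arbitrary:

```cpp
#include "llama.h"

// build a chain whose last stage is the adaptive-p sampler:
// target = 0.5f (steer the running average toward 50% probability),
// decay  = 0.9f (effective history ~= 10 tokens), seed = 42
static llama_sampler * make_chain() {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_adaptive_p(0.5f, 0.9f, 42));
    return chain;
}

// per decode step: llama_token tok = llama_sampler_sample(chain, lctx, -1);
// when finished:   llama_sampler_free(chain);
```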