From e7a892065dfbb8ad2cf9ca43acd408e72bfc0321 Mon Sep 17 00:00:00 2001
From: ddh0
Date: Sun, 28 Dec 2025 20:31:26 -0600
Subject: [PATCH] fix EMA cold start

- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f - clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f - clamped_decay)`

this fixes a "cold start" problem with the moving average: the EMA state is
now seeded as if the sampler had already been running at exactly `target`,
so `weighted_sum / total_weight == target` from the very first token, instead
of starting from an empty history that overreacts to the first few samples
---
 src/llama-sampling.cpp | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 5a823ca457..137c865c30 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -2338,7 +2338,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 // ref: https://github.com/ggml-org/llama.cpp/pull/17927
 struct llama_sampler_adaptive_p {
     const float    target; // target probability (0.0 - 1.0; negative = disabled)
-    const float    decay;  // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
+    const float    decay;  // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
     const uint32_t seed;   // RNG seed
     std::mt19937   rng;    // RNG
     float weighted_sum;    // sum(p_i * decay^i)
@@ -2397,15 +2397,18 @@ static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_to
     const int idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;
 
-    // update history with the original probability of the selected token
+    // update EMA with the original probability of the selected token
     ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
     ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
 }
 
 static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
-    ctx->weighted_sum = 0.0f;
-    ctx->total_weight = 0.0f;
+    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
+    // ctx->target and ctx->decay never change after init, so it's safe to keep them as-is.
+    // original_probs is completely overwritten on every call to _apply,
+    // so we only need to reset the EMA state.
+    ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
+    ctx->total_weight = 1.0f / (1.0f - ctx->decay);
 }
 
 static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
@@ -2413,10 +2416,11 @@ static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_
     auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
     auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
 
-    result_ctx->rng = ctx->rng;
-    result_ctx->weighted_sum = ctx->weighted_sum;
-    result_ctx->total_weight = ctx->total_weight;
-    result_ctx->original_probs.reserve(ctx->original_probs.capacity());
+    // copy everything (target, decay, and seed are already set)
+    result_ctx->original_probs = ctx->original_probs;
+    result_ctx->weighted_sum   = ctx->weighted_sum;
+    result_ctx->total_weight   = ctx->total_weight;
+    result_ctx->rng            = ctx->rng;
 
     return result;
 }
@@ -2440,15 +2444,16 @@ struct llama_sampler * llama_sampler_init_adaptive_p(
         uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
+    float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_adaptive_p_i,
         /* .ctx   = */ new llama_sampler_adaptive_p {
             /* .target         = */ target,
-            /* .decay          = */ std::clamp(decay, 0.0f, 0.99f),
+            /* .decay          = */ clamped_decay,
             /* .seed           = */ seed_cur,
             /* .rng            = */ std::mt19937(seed_cur),
-            /* .weighted_sum   = */ 0.0f,
-            /* .total_weight   = */ 0.0f,
+            /* .weighted_sum   = */ target / (1.0f - clamped_decay),
+            /* .total_weight   = */ 1.0f / (1.0f - clamped_decay),
             /* .original_probs = */ {},
         }
    );
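
Note: the seed values follow from the geometric series sum(decay^i, i = 0..inf)
= 1/(1-decay): pre-filling the history as if every past token had probability
`target` gives weighted_sum / total_weight == target exactly. A side effect is
that total_weight becomes a fixed point of the update, since
1.0f + decay * (1.0f / (1.0f - decay)) == 1.0f / (1.0f - decay).

The standalone sketch below illustrates the behavior change; it is not part of
the patch, and the target/decay constants and the toy probability stream are
invented for illustration. It replays the same update rule as
llama_sampler_adaptive_p_apply with the old (zeroed) and the new (pre-warmed)
initial state:

    #include <cstdio>

    int main() {
        const float target = 0.30f; // hypothetical target probability
        const float decay  = 0.90f; // hypothetical EMA decay

        // old init: empty history
        float ws_old = 0.0f;
        float tw_old = 0.0f;
        // new init: history pre-filled as if every past token had p == target
        float ws_new = target / (1.0f - decay);
        float tw_new = 1.0f / (1.0f - decay);

        // toy stream of selected-token probabilities
        const float probs[] = { 0.05f, 0.60f, 0.10f, 0.45f, 0.20f };

        for (float p : probs) {
            // same update rule as llama_sampler_adaptive_p_apply
            ws_old = p + decay * ws_old;   tw_old = 1.0f + decay * tw_old;
            ws_new = p + decay * ws_new;   tw_new = 1.0f + decay * tw_new;
            printf("p = %.2f | old avg = %.3f | new avg = %.3f\n",
                   p, ws_old / tw_old, ws_new / tw_new);
        }
        return 0;
    }

With the old state the first "average" is just the first sample (0.05 here), so
any feedback driven by the EMA overreacts early; with the new state the average
starts at 0.300 and moves gently to 0.275, drifting only as real observations
accumulate.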