fix cold start EMA
- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f - clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f - clamped_decay)`

This fixes a "cold start" problem with the moving average.
This commit is contained in:
parent f0d3f13124
commit e7a892065d
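Why these particular values (a derivation for context, not part of the commit text): the sampler estimates the recent average probability as `weighted_sum / total_weight`, with `weighted_sum = sum(p_i * decay^i)` and `total_weight = sum(decay^i)` over past tokens. Starting both at zero means the first estimate is exactly `p_0` with no smoothing at all, so early output is dominated by a handful of samples. The new values are the infinite-history limits for a history that always hit the target: `target * (1 + decay + decay^2 + ...) = target / (1 - decay)` and `1 + decay + decay^2 + ... = 1 / (1 - decay)`. The EMA therefore reads exactly `target` before the first token and adapts smoothly from there. A minimal standalone sketch of the effect (illustrative values, not llama.cpp code):

```cpp
#include <cstdio>

int main() {
    const float target = 0.30f; // desired average probability
    const float decay  = 0.90f; // history ~= 1/(1-decay) = 10 tokens

    // old behaviour: cold start from zero
    float ws_old = 0.0f;
    float tw_old = 0.0f;
    // new behaviour: start at the steady state of a history that hit the target
    float ws_new = target / (1.0f - decay); // = 3.0
    float tw_new = 1.0f   / (1.0f - decay); // = 10.0

    const float probs[] = { 0.95f, 0.05f, 0.40f }; // example observed probabilities
    for (float p : probs) {
        ws_old = p + decay * ws_old;  tw_old = 1.0f + decay * tw_old;
        ws_new = p + decay * ws_new;  tw_new = 1.0f + decay * tw_new;
        std::printf("old EMA = %.3f   new EMA = %.3f\n",
                    ws_old / tw_old, ws_new / tw_new);
    }
    // old EMA jumps to 0.950 on the first token; the new one stays near
    // the 0.300 prior and adapts gradually
    return 0;
}
```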
@@ -2338,7 +2338,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 // ref: https://github.com/ggml-org/llama.cpp/pull/17927
 struct llama_sampler_adaptive_p {
     const float target;  // target probability (0.0 - 1.0; negative = disabled)
-    const float decay;   // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
+    const float decay;   // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
     const uint32_t seed; // RNG seed
     std::mt19937 rng;    // RNG
     float weighted_sum;  // sum(p_i * decay^i)
@@ -2397,15 +2397,18 @@ static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_to
     const int idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;

-    // update history with the original probability of the selected token
+    // update EMA with the original probability of the selected token
     ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
     ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
 }

 static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
-    ctx->weighted_sum = 0.0f;
-    ctx->total_weight = 0.0f;
+    // ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
+    // original_probs is completely overwritten on every call to _apply,
+    // so we only need to reset the EMA state.
+    ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
+    ctx->total_weight = 1.0f / (1.0f - ctx->decay);
 }

 static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
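A useful consequence of the new reset value (an observation, not stated in the commit): `1.0f / (1.0f - decay)` is a fixed point of the `total_weight` update in `_apply`, because `1 + decay / (1 - decay) = 1 / (1 - decay)`. After this change `total_weight` never moves, and the estimate `weighted_sum / total_weight` is simply `(1 - decay) * weighted_sum`. A tiny standalone check of that fixed point, with an assumed example decay:

```cpp
#include <cassert>
#include <cmath>

int main() {
    const float decay = 0.9f;                    // assumed example value
    float total_weight = 1.0f / (1.0f - decay);  // new reset value: 10.0
    for (int i = 0; i < 100; ++i) {
        total_weight = 1.0f + decay * total_weight; // the update from _apply
    }
    // 1 + d/(1-d) = 1/(1-d): total_weight never drifts from the reset value
    assert(std::fabs(total_weight - 1.0f / (1.0f - decay)) < 1e-3f);
    return 0;
}
```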
@@ -2413,10 +2416,11 @@ static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_
     auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
     auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;

-    result_ctx->rng = ctx->rng;
-    result_ctx->weighted_sum = ctx->weighted_sum;
-    result_ctx->total_weight = ctx->total_weight;
-    result_ctx->original_probs.reserve(ctx->original_probs.capacity());
+    // copy everything (target, decay, and seed are already set)
+    result_ctx->original_probs = ctx->original_probs;
+    result_ctx->weighted_sum = ctx->weighted_sum;
+    result_ctx->total_weight = ctx->total_weight;
+    result_ctx->rng = ctx->rng;

     return result;
 }
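Note that the clone change is not just a reordering: `std::vector::reserve` only allocates capacity and copies no elements, so the old clone left `original_probs` empty until the next `_apply` overwrote it. Plain assignment copies the contents, making the clone a full snapshot of the sampler state. A minimal illustration of the difference (generic C++, not llama.cpp code):

```cpp
#include <cassert>
#include <vector>

int main() {
    const std::vector<float> src = { 0.5f, 0.3f, 0.2f };

    std::vector<float> old_way;
    old_way.reserve(src.capacity()); // old clone: allocates, copies nothing
    assert(old_way.empty());

    std::vector<float> new_way = src; // new clone: element-wise copy
    assert(new_way == src);
    return 0;
}
```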
@@ -2440,15 +2444,16 @@ struct llama_sampler * llama_sampler_init_adaptive_p(
     uint32_t seed
 ) {
     auto seed_cur = get_rng_seed(seed);
+    float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_adaptive_p_i,
         /* .ctx   = */ new llama_sampler_adaptive_p {
             /* .target         = */ target,
-            /* .decay          = */ std::clamp(decay, 0.0f, 0.99f),
+            /* .decay          = */ clamped_decay,
             /* .seed           = */ seed_cur,
             /* .rng            = */ std::mt19937(seed_cur),
-            /* .weighted_sum   = */ 0.0f,
-            /* .total_weight   = */ 0.0f,
+            /* .weighted_sum   = */ target / (1.0f - clamped_decay),
+            /* .total_weight   = */ 1.0f / (1.0f - clamped_decay),
             /* .original_probs = */ {},
         }
     );
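For reference, a hypothetical usage sketch. It assumes the sampler-chain API from llama.h together with the `llama_sampler_init_adaptive_p` entry point added by PR #17927 (which may not be present in every build); the parameter values are only examples:

```cpp
#include "llama.h"

// build a sampler chain ending in the adaptive-p sampler
struct llama_sampler * make_chain() {
    struct llama_sampler * chain =
        llama_sampler_chain_init(llama_sampler_chain_default_params());
    // target = 0.3, decay = 0.9 (history ~= 10 tokens), default RNG seed
    llama_sampler_chain_add(chain,
        llama_sampler_init_adaptive_p(0.3f, 0.9f, LLAMA_DEFAULT_SEED));
    return chain;
}
```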