diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 738fd05caa..5871668d96 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -2389,10 +2389,12 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /* // Finally, the computed target is clamped to [min_target, max_target] to // prevent extreme values that could destabilize sampling. // -static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, - float min_target, - float max_target, - float tail_decay) { +static float llama_sampler_power_law_compute_target( + const llama_sampler_power_law * ctx, + float min_target, + float max_target, + float tail_decay) { + float computed_target = ctx->target; size_t sz = ctx->window.size(); @@ -2416,6 +2418,10 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la weight *= tail_decay; } + // Shift weights to account for new value taking position 0 + // All existing values age by 1, so multiply their weights by decay + float shifted_weighted_sum = weighted_sum * tail_decay; + // Compute total weight after new value is inserted // When full: sz elements remain (oldest evicted, new added) // When not full: sz + 1 elements (new added, nothing evicted) @@ -2428,10 +2434,6 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay); } - // Shift weights to account for new value taking position 0 - // All existing values age by 1, so multiply their weights by decay - float shifted_weighted_sum = weighted_sum * tail_decay; - // Solve for the new value that achieves target weighted average float next_value = (ctx->target * total_weight) - shifted_weighted_sum; @@ -2446,7 +2448,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok auto * ctx = (llama_sampler_power_law *) smpl->ctx; if (ctx->target < 0.0f) { - 
fprintf(stderr, "Target below zero, sampling from distribution\n"); // no-op: just sample from the distribution as-is llama_sampler_softmax_impl(cur_p, false); const int idx = llama_sample_dist(cur_p, ctx->rng); @@ -2462,9 +2463,9 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok // target computation parameters const float min_target = 0.0f; const float max_target = 1.0f; - const float tail_decay = 0.50f; // Exponential decay factor for history weighting - // Lower = faster response, higher = more stability - // Effective window ≈ 1/(1-decay) ≈ 20 tokens + const float tail_decay = 0.50f; // exponential decay factor for history weighting + // lower = faster response, higher = more stability + // effective window ≈ 1/(1-decay) ≈ 2 tokens // compute probabilities to get the "original" values llama_sampler_softmax_impl(cur_p, false); @@ -2479,7 +2480,10 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok // calculate adaptive target float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay); + // // power law transform + // + for (size_t i = 0; i < cur_p->size; ++i) { float p = cur_p->data[i].p; float normalized_distance = std::abs(p - computed_target) / distribution_width;