diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7b48e5d970..7684c8f38c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -2365,22 +2365,6 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /* return "power-law"; } -// compute the adapted target probability for the current sampling step -static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) { - const float base_target = ctx->target; - if (ctx->total_weight == 0.0f) { - fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target); fflush(stderr); - return base_target; - } - float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight); - fprintf(stderr, "power-law: compute_target: raw target = %.3f\n", target); - - // clamp result to [0.0, 1.0] - target = std::clamp(target, 0.0f, 1.0f); - fprintf(stderr, "power-law: compute_target: clamped target = %.3f\n", target); fflush(stderr); - return target; -} - static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_power_law *) smpl->ctx; @@ -2400,13 +2384,18 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok ctx->original_probs[i] = cur_p->data[i].p; } - float computed_target = llama_sampler_power_law_compute_target(ctx); + // compute the adapted target probability for the current sampling step + float computed_target = std::clamp( + ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + fprintf(stderr, "power-law: computed target = %.3f\n", computed_target); // // power law transform // - fprintf(stderr, "power-law: transform: cur_p->size = %d\n", (size_t)cur_p->size); + fprintf(stderr, "power-law: cur_p->size = %d\n", (int)cur_p->size); for (size_t i = 0; i < cur_p->size; ++i) { float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH; cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist); @@ -2421,7 +2410,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok // update running history with the original probability of the selected token float original_p = ctx->original_probs[idx]; - fprintf(stderr, "power-law: original prob was %.3f\n", original_p); ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum; fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum); ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;