simplify target computation

last commit with debug logging!
This commit is contained in:
ddh0 2025-12-15 21:02:11 -06:00
parent 0344068cf1
commit 1c2d2e900d
1 changed files with 7 additions and 19 deletions

View File

@ -2365,22 +2365,6 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
return "power-law";
}
// compute the adapted target probability for the current sampling step
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) {
const float base_target = ctx->target;
if (ctx->total_weight == 0.0f) {
fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target); fflush(stderr);
return base_target;
}
float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight);
fprintf(stderr, "power-law: compute_target: raw target = %.3f\n", target);
// clamp result to [0.0, 1.0]
target = std::clamp(target, 0.0f, 1.0f);
fprintf(stderr, "power-law: compute_target: clamped target = %.3f\n", target); fflush(stderr);
return target;
}
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
@ -2400,13 +2384,18 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
ctx->original_probs[i] = cur_p->data[i].p;
}
float computed_target = llama_sampler_power_law_compute_target(ctx);
// compute the adapted target probability for the current sampling step
float computed_target = std::clamp(
ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
0.0f, 1.0f
);
fprintf(stderr, "power-law: computed target = %.3f\n", computed_target);
//
// power law transform
//
fprintf(stderr, "power-law: transform: cur_p->size = %d\n", (size_t)cur_p->size);
fprintf(stderr, "power-law: cur_p->size = %d\n", (int)cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
@ -2421,7 +2410,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
// update running history with the original probability of the selected token
float original_p = ctx->original_probs[idx];
fprintf(stderr, "power-law: original prob was %.3f\n", original_p);
ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum;
fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum);
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;