improve logging messages in llama_sampler_power_law

This commit is contained in:
ddh0 2025-12-15 09:25:05 -06:00 committed by GitHub
parent 6e66095e1f
commit 9c50b573f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -2369,15 +2369,15 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) { static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) {
const float base_target = ctx->target; const float base_target = ctx->target;
if (ctx->total_weight == 0.0f) { if (ctx->total_weight == 0.0f) {
fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target); fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target); fflush(stderr);
return base_target; return base_target;
} }
float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight); float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight);
fprintf(stderr, "power-law: compute_target: target = %.3f\n", target); fprintf(stderr, "power-law: compute_target: raw target = %.3f\n", target);
// clamp result to [0.0, 1.0] // clamp result to [0.0, 1.0]
target = std::clamp(target, 0.0f, 1.0f); target = std::clamp(target, 0.0f, 1.0f);
fprintf(stderr, "power-law: compute_target: target (post-clamp) = %.3f\n", target); fflush(stderr); fprintf(stderr, "power-law: compute_target: clamped target = %.3f\n", target); fflush(stderr);
return target; return target;
} }
@ -2407,7 +2407,7 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
// power law transform // power law transform
// //
fprintf(stderr, "power-law: transform: cur_p->size = %.3f\n", (double)cur_p->size); fprintf(stderr, "power-law: transform: cur_p->size = %d\n", (size_t)cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH; float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist); cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);