improve logging messages in llama_sampler_power_law
This commit is contained in:
parent
6e66095e1f
commit
9c50b573f5
|
|
@ -2369,15 +2369,15 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
|
||||||
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) {
|
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) {
|
||||||
const float base_target = ctx->target;
|
const float base_target = ctx->target;
|
||||||
if (ctx->total_weight == 0.0f) {
|
if (ctx->total_weight == 0.0f) {
|
||||||
fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target);
|
fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target); fflush(stderr);
|
||||||
return base_target;
|
return base_target;
|
||||||
}
|
}
|
||||||
float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight);
|
float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight);
|
||||||
fprintf(stderr, "power-law: compute_target: target = %.3f\n", target);
|
fprintf(stderr, "power-law: compute_target: raw target = %.3f\n", target);
|
||||||
|
|
||||||
// clamp result to [0.0, 1.0]
|
// clamp result to [0.0, 1.0]
|
||||||
target = std::clamp(target, 0.0f, 1.0f);
|
target = std::clamp(target, 0.0f, 1.0f);
|
||||||
fprintf(stderr, "power-law: compute_target: target (post-clamp) = %.3f\n", target); fflush(stderr);
|
fprintf(stderr, "power-law: compute_target: clamped target = %.3f\n", target); fflush(stderr);
|
||||||
return target;
|
return target;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2407,7 +2407,7 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
||||||
// power law transform
|
// power law transform
|
||||||
//
|
//
|
||||||
|
|
||||||
fprintf(stderr, "power-law: transform: cur_p->size = %.3f\n", (double)cur_p->size);
|
fprintf(stderr, "power-law: transform: cur_p->size = %zu\n", (size_t)cur_p->size);
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
|
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
|
||||||
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
|
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue