simplify target computation
last commit with debug logging!
This commit is contained in:
parent
0344068cf1
commit
1c2d2e900d
|
|
@ -2365,22 +2365,6 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
|
|||
return "power-law";
|
||||
}
|
||||
|
||||
// compute the adapted target probability for the current sampling step
|
||||
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) {
|
||||
const float base_target = ctx->target;
|
||||
if (ctx->total_weight == 0.0f) {
|
||||
fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target); fflush(stderr);
|
||||
return base_target;
|
||||
}
|
||||
float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight);
|
||||
fprintf(stderr, "power-law: compute_target: raw target = %.3f\n", target);
|
||||
|
||||
// clamp result to [0.0, 1.0]
|
||||
target = std::clamp(target, 0.0f, 1.0f);
|
||||
fprintf(stderr, "power-law: compute_target: clamped target = %.3f\n", target); fflush(stderr);
|
||||
return target;
|
||||
}
|
||||
|
||||
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
|
||||
|
||||
|
|
@ -2400,13 +2384,18 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
ctx->original_probs[i] = cur_p->data[i].p;
|
||||
}
|
||||
|
||||
float computed_target = llama_sampler_power_law_compute_target(ctx);
|
||||
// compute the adapted target probability for the current sampling step
|
||||
float computed_target = std::clamp(
|
||||
ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
|
||||
0.0f, 1.0f
|
||||
);
|
||||
fprintf(stderr, "power-law: computed target = %.3f\n", computed_target);
|
||||
|
||||
//
|
||||
// power law transform
|
||||
//
|
||||
|
||||
fprintf(stderr, "power-law: transform: cur_p->size = %d\n", (size_t)cur_p->size);
|
||||
fprintf(stderr, "power-law: cur_p->size = %d\n", (int)cur_p->size);
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
|
||||
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
|
||||
|
|
@ -2421,7 +2410,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
|
||||
// update running history with the original probability of the selected token
|
||||
float original_p = ctx->original_probs[idx];
|
||||
fprintf(stderr, "power-law: original prob was %.3f\n", original_p);
|
||||
ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum;
|
||||
fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum);
|
||||
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
|
||||
|
|
|
|||
Loading…
Reference in New Issue