does this fix it?

This commit is contained in:
ddh0 2025-12-14 01:55:02 -06:00
parent 9613c48172
commit 2a3f579d1f
1 changed files with 21 additions and 28 deletions

View File

@ -2358,23 +2358,20 @@ static const char * llama_sampler_power_law_name(const struct llama_sampler * /*
return "power-law"; return "power-law";
} }
// compute the adaptive target probability for the current sampling step // compute the adapted target probability for the current sampling step
static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, float decay) { static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx) {
const float base_target = ctx->target;
if (ctx->total_weight == 0.0f) { if (ctx->total_weight == 0.0f) {
// if there is no history, just use base target fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", base_target);
fprintf(stderr, "power-law: compute_target: total_weight == 0.0 (target fixed at %.3f)\n", ctx->target); return base_target;
fflush(stderr);
return ctx->target;
} }
float target = 2.0f * base_target - (ctx->weighted_sum / ctx->total_weight);
fprintf(stderr, "power-law: compute_target: target = %.3f\n", target);
// maintain a running weighted sum with exponential decay // clamp result to [0.0, 1.0]
float new_total_weight = 1.0f + decay * ctx->total_weight; target = std::max(0.0f, std::min(target, 1.0f));
fprintf(stderr, "power-law: compute_target: new_total_weight = %.3f\n", new_total_weight); fflush(stderr); fprintf(stderr, "power-law: compute_target: target (post-clamp) = %.3f\n", target); fflush(stderr);
float next_value = ctx->target * new_total_weight - decay * ctx->weighted_sum; return target;
fprintf(stderr, "power-law: compute_target: next_value = %.3f\n", next_value); fflush(stderr);
// clamp to [0.0, 1.0]
return std::max(0.0f, std::min(next_value, 1.0f));
} }
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@ -2393,11 +2390,6 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
const float decay = std::min(ctx->decay, 0.99f); const float decay = std::min(ctx->decay, 0.99f);
fprintf(stderr, "power-law: decay = %.3f\n", decay); fflush(stderr); fprintf(stderr, "power-law: decay = %.3f\n", decay); fflush(stderr);
// fixed power law transform parameters
const float distribution_width = 0.3f;
const float peak_logit_value = 5.0f;
const float tail_heaviness = 2.0f;
// get the original probabilities // get the original probabilities
llama_sampler_softmax_impl(cur_p, false); llama_sampler_softmax_impl(cur_p, false);
@ -2408,21 +2400,22 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
original_probs.push_back(cur_p->data[i].p); original_probs.push_back(cur_p->data[i].p);
} }
float computed_target = llama_sampler_power_law_compute_target(ctx, decay); float computed_target = llama_sampler_power_law_compute_target(ctx);
fprintf(stderr, "power-law: computed_target = %.3f\n", computed_target); fflush(stderr); fprintf(stderr, "power-law: computed_target = %.3f\n", computed_target); fflush(stderr);
// //
// power law transform // power law transform
// //
// transformation constants
const float distribution_width = 0.3f;
const float peak_logit_value = 5.0f;
const float inv_width = 1.0f / distribution_width;
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
float p = cur_p->data[i].p; float dist = (cur_p->data[i].p - computed_target) * inv_width;
fprintf(stderr, "power-law: transform: p = %.3f\n", p); fflush(stderr); cur_p->data[i].logit = peak_logit_value / (1.0f + dist * dist);
float normed_distance = std::abs(p - computed_target) / distribution_width;
fprintf(stderr, "power-law: transform: normed_distance = %.3f\n", normed_distance); fflush(stderr);
float new_p = peak_logit_value / (1.0f + std::pow(normed_distance, tail_heaviness));
fprintf(stderr, "power-law: transform: new_p = %.3f\n", new_p); fflush(stderr);
cur_p->data[i].logit = new_p;
} }
llama_sampler_softmax_impl(cur_p, false); llama_sampler_softmax_impl(cur_p, false);
@ -2430,7 +2423,7 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
// sample from transformed distribution // sample from transformed distribution
const int idx = llama_sample_dist(cur_p, ctx->rng); const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx; cur_p->selected = idx;
fprintf(stderr, "power-law: selected token %d\n", idx); fflush(stderr); fprintf(stderr, "power-law: selected token at index %d\n", idx); fflush(stderr);
// update running history with the original probability of the selected token // update running history with the original probability of the selected token
float original_p = original_probs[idx]; float original_p = original_probs[idx];