diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 7684c8f38c..77ec141a56 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -2370,10 +2370,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
 
     if (ctx->target < 0.0f) {
         // no-op: just sample from the distribution as-is
-        fprintf(stderr, "power-law: no-op!\n"); fflush(stderr);
         llama_sampler_softmax_impl(cur_p, false);
-        const int idx = llama_sample_dist(cur_p, ctx->rng);
-        cur_p->selected = idx;
+        cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
         return;
     }
 
@@ -2389,13 +2387,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
         ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
         0.0f, 1.0f
     );
-    fprintf(stderr, "power-law: computed target = %.3f\n", computed_target);
 
-    //
     // power law transform
-    //
-
-    fprintf(stderr, "power-law: cur_p->size = %d\n", (int)cur_p->size);
     for (size_t i = 0; i < cur_p->size; ++i) {
         float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
         cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
@@ -2406,14 +2399,10 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
     // sample from transformed distribution
     const int idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;
-    fprintf(stderr, "power-law: selected token at index %d\n", idx);
 
     // update running history with the original probability of the selected token
-    float original_p = ctx->original_probs[idx];
-    ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum;
-    fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum);
-    ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
-    fprintf(stderr, "power-law: updated ctx->total_weight = %.3f\n", ctx->total_weight); fflush(stderr);
+    ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
+    ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time
 }
 
 static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
@@ -2453,15 +2442,12 @@ struct llama_sampler * llama_sampler_init_power_law(
     float    decay,
     uint32_t seed
 ) {
-    const float _decay = std::min(decay, 0.99f);
-    fprintf(stderr, "power-law: init: target %.3f, decay %.3f\n", (double)target, (double)_decay);
-    fflush(stderr);
     auto seed_cur = get_rng_seed(seed);
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_power_law_i,
         /* .ctx   = */ new llama_sampler_power_law {
-            /* .target       = */ target,
-            /* .decay        = */ _decay,
+            /* .target       = */ std::min(target, 1.0f), // upper bound only: target < 0 must survive, it selects the no-op path in apply()
+            /* .decay        = */ std::clamp(decay, 0.0f, 0.99f),
             /* .seed         = */ seed_cur,
             /* .rng          = */ std::mt19937(seed_cur),
             /* .weighted_sum = */ 0.0f,