diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 6beb927a6c..78fe7706b9 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -2349,11 +2349,18 @@ struct llama_sampler_power_law { std::mt19937 rng; // historical token probabilities weighted by recency - float weighted_sum; + float weighted_sum; // sum of weights, converges to 1/(1-decay) - float total_weight; + float total_weight; + // used to store original token probabilities (needed for history update after selection) + std::vector original_probs; }; +// transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) { return "power-law"; } @@ -2369,7 +2376,7 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la fprintf(stderr, "power-law: compute_target: target = %.3f\n", target); // clamp result to [0.0, 1.0] - target = std::max(0.0f, std::min(target, 1.0f)); + target = std::clamp(target, 0.0f, 1.0f); fprintf(stderr, "power-law: compute_target: target (post-clamp) = %.3f\n", target); fflush(stderr); return target; } @@ -2379,43 +2386,32 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok if (ctx->target < 0.0f) { // no-op: just sample from the distribution as-is - fprintf(stderr, "power-law: no-op!"); fflush(stderr); + fprintf(stderr, "power-law: no-op!"); llama_sampler_softmax_impl(cur_p, false); const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; return; } - // clamp decay to avoid degenerate case at 1.0 (unbounded accumulation) - const float decay = std::min(ctx->decay, 0.99f); - fprintf(stderr, "power-law: decay = %.3f\n", decay); fflush(stderr); - // get the original probabilities llama_sampler_softmax_impl(cur_p, false); - // store the original probabilities (needed for history update after selection) - std::vector original_probs; - original_probs.reserve(cur_p->size); + // store the original probabilities + ctx->original_probs.resize(cur_p->size); for (size_t i = 0; i < cur_p->size; ++i) { - original_probs.push_back(cur_p->data[i].p); + ctx->original_probs[i] = cur_p->data[i].p; } float computed_target = llama_sampler_power_law_compute_target(ctx); - fprintf(stderr, "power-law: computed_target = %.3f\n", computed_target); fflush(stderr); + fprintf(stderr, "power-law: computed_target = %.3f\n", computed_target); // // power law transform // - // transformation constants - const float distribution_width = 0.3f; - const float peak_logit_value = 5.0f; - - const float inv_width = 1.0f / distribution_width; - for (size_t i = 0; i < cur_p->size; ++i) { - float dist = (cur_p->data[i].p - computed_target) * inv_width; - cur_p->data[i].logit = peak_logit_value / (1.0f + dist * dist); + float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH; + cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist); } llama_sampler_softmax_impl(cur_p, false); @@ -2423,14 +2419,14 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok // sample from transformed distribution const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; - fprintf(stderr, "power-law: selected token at index %d\n", idx); fflush(stderr); + fprintf(stderr, "power-law: selected token at index %d\n", idx); // update running history with the original probability of the selected token - float original_p = original_probs[idx]; - fprintf(stderr, "power-law: original prob was %.3f\n", original_p); fflush(stderr); - ctx->weighted_sum = original_p + decay * ctx->weighted_sum; - fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum); fflush(stderr); - ctx->total_weight = 1.0f + decay * ctx->total_weight; + float original_p = ctx->original_probs[idx]; + fprintf(stderr, "power-law: original prob was %.3f\n", original_p); + ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum; + fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum); + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; fprintf(stderr, "power-law: updated ctx->total_weight = %.3f\n", ctx->total_weight); fflush(stderr); } @@ -2448,6 +2444,7 @@ static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_s result_ctx->rng = ctx->rng; result_ctx->weighted_sum = ctx->weighted_sum; result_ctx->total_weight = ctx->total_weight; + result_ctx->original_probs.reserve(ctx->original_probs.capacity()); return result; } @@ -2475,7 +2472,7 @@ struct llama_sampler * llama_sampler_init_power_law( /* .iface = */ &llama_sampler_power_law_i, /* .ctx = */ new llama_sampler_power_law { /* .target = */ target, - /* .decay = */ decay, + /* .decay = */ std::min(decay, 0.99f), /* .seed = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .weighted_sum = */ 0.0f,