diff --git a/include/llama.h b/include/llama.h
index b52eaacfa7..7e1e65523b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1289,6 +1289,15 @@ extern "C" {
                           const char ** seq_breakers,
                                  size_t num_breakers);
 
+    /// @details power law sampler, reshapes probability distribution to target specific probability ranges
+    /// ref: https://github.com/MrJackSpade/llama.cpp
+    /// ref: [PR]
+    LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
+                               float   target,        // target probability (0.0 to 1.0)
+                               float   target_range,  // adaptive target range (±range from target)
+                             int32_t   queue_size,    // rolling history window size for adaptation
+                            uint32_t   seed);         // RNG seed
+
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                              int32_t   n_vocab,
                              int32_t   n_logit_bias,
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 3f4a729bc3..6ef8121d7c 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -2313,6 +2313,140 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
     return result;
 }
 
+// power-law
+// ref: https://github.com/MrJackSpade/llama.cpp/tree/master
+// ref: [PR]
+
+struct llama_sampler_power_law {
+    const float    target;
+    const float    target_range;
+    const int32_t  queue_size;
+    const uint32_t seed;
+
+    std::mt19937 rng;
+    ring_buffer<float> history;
+};
+
+static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
+    return "power-law";
+}
+
+static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_power_law *) smpl->ctx;
+
+    // these don't need to be modified or exposed to the user
+    const float peak_logit_value = 3.0f;
+    const float tail_heaviness   = 3.0f;
+
+    const float min_target = ctx->target - ctx->target_range;
+    const float max_target = ctx->target + ctx->target_range;
+
+    // compute probabilities to get the "original" values
+    llama_sampler_softmax_impl(cur_p, false);
+
+    // store original probabilities (needed for history update)
+    std::vector<float> original_probs;
+    original_probs.reserve(cur_p->size);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        original_probs.push_back(cur_p->data[i].p);
+    }
+
+    // calculate adaptive target
+    float computed_target = ctx->target;
+    if (ctx->history.size() > 0) {
+        float  sum_excluding_oldest = 0.0f;
+        size_t sz = ctx->history.size();
+
+        // sum all except the oldest element
+        for (size_t i = 0; i < sz - 1; ++i) {
+            sum_excluding_oldest += ctx->history.rat(i);
+        }
+
+        float next_value = (ctx->target * ctx->queue_size) - sum_excluding_oldest;
+        computed_target  = std::max(min_target, std::min(next_value, max_target));
+    }
+
+    // find closest token (for degenerate width ~ 0 case)
+    float min_distance = FLT_MAX;
+    int   closest_token_idx = -1;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float distance = std::abs(cur_p->data[i].p - computed_target);
+        if (distance < min_distance) {
+            min_distance = distance;
+            closest_token_idx = (int) i;
+        }
+    }
+
+    // apply power law transformation
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = cur_p->data[i].p;
+
+        float distance = std::abs(p - computed_target);
+        float normalized_distance = distance / 0.2f;
+        cur_p->data[i].logit = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness));
+    }
+
+    llama_sampler_softmax_impl(cur_p, false);
+
+    // sample from distribution
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+    // set sampled token
+    cur_p->selected = idx;
+
+    // update history with ORIGINAL probability
+    ctx->history.push_back(original_probs[idx]);
+}
+
+static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_power_law *) smpl->ctx;
+    ctx->history = ring_buffer<float>(ctx->queue_size);
+}
+
+static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
+    auto * result = llama_sampler_init_power_law(ctx->target, ctx->target_range, ctx->queue_size, ctx->seed);
+    auto * result_ctx = (llama_sampler_power_law *) result->ctx;
+
+    result_ctx->history = ctx->history;
+
+    return result;
+}
+
+static void llama_sampler_power_law_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_power_law *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_power_law_i = {
+    /* .name   = */ llama_sampler_power_law_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_power_law_apply,
+    /* .reset  = */ llama_sampler_power_law_reset,
+    /* .clone  = */ llama_sampler_power_law_clone,
+    /* .free   = */ llama_sampler_power_law_free,
+};
+
+struct llama_sampler * llama_sampler_init_power_law(
+        float    target,
+        float    target_range,
+        int32_t  queue_size,
+        uint32_t seed
+) {
+    auto seed_cur = get_rng_seed(seed);
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_power_law_i,
+        /* .ctx   = */ new llama_sampler_power_law {
+            /* .target       = */ target,
+            /* .target_range = */ target_range,
+            /* .queue_size   = */ queue_size,
+            /* .seed         = */ seed_cur,
+            /* .rng          = */ std::mt19937(seed_cur),
+            /* .history      = */ ring_buffer<float>(queue_size),
+        }
+    );
+}
+
 // logit-bias
 
 struct llama_sampler_logit_bias {
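
For reviewers, a minimal usage sketch (not part of the patch): it attaches the new sampler to a chain through the existing llama_sampler_chain API and draws one token. The wrapper name sample_with_power_law and the parameter values are illustrative assumptions, not recommendations. Because the power-law sampler selects a token itself (it sets cur_p->selected, much like mirostat), it is placed last in the chain and no dist sampler follows it.

    // usage sketch (illustrative, not part of the patch)
    #include "llama.h"

    static llama_token sample_with_power_law(llama_context * lctx) {
        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        // the power-law sampler goes last: it selects the token itself
        llama_sampler_chain_add(chain, llama_sampler_init_power_law(
            /* target       = */ 0.30f,              // aim for tokens whose original probability is near 0.30
            /* target_range = */ 0.15f,              // let the adaptive target drift within [0.15, 0.45]
            /* queue_size   = */ 32,                 // rolling window over the last 32 sampled probabilities
            /* seed         = */ LLAMA_DEFAULT_SEED));

        const llama_token tok = llama_sampler_sample(chain, lctx, -1);

        llama_sampler_free(chain);
        return tok;
    }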