initial commit for branch
This commit is contained in:
parent
34ce48d97a
commit
774cf23ee5
|
|
@ -1289,6 +1289,15 @@ extern "C" {
|
||||||
const char ** seq_breakers,
|
const char ** seq_breakers,
|
||||||
size_t num_breakers);
|
size_t num_breakers);
|
||||||
|
|
||||||
|
/// @details power law sampler, reshapes probability distribution to target specific probability ranges
|
||||||
|
/// ref: https://github.com/MrJackSpade/llama.cpp
|
||||||
|
/// ref: [PR]
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
|
||||||
|
float target, // target probability (0.0 to 1.0)
|
||||||
|
float target_range, // adaptive target range (±range from target)
|
||||||
|
int32_t queue_size, // rolling history window size for adaptation
|
||||||
|
uint32_t seed); // RNG seed
|
||||||
|
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
|
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
|
||||||
int32_t n_vocab,
|
int32_t n_vocab,
|
||||||
int32_t n_logit_bias,
|
int32_t n_logit_bias,
|
||||||
|
|
|
||||||
|
|
@ -2313,6 +2313,140 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// power-law
|
||||||
|
// ref: https://github.com/MrJackSpade/llama.cpp/tree/master
|
||||||
|
// ref: [PR]
|
||||||
|
|
||||||
|
struct llama_sampler_power_law {
|
||||||
|
const float target;
|
||||||
|
const float target_range;
|
||||||
|
const int32_t queue_size;
|
||||||
|
const uint32_t seed;
|
||||||
|
|
||||||
|
std::mt19937 rng;
|
||||||
|
ring_buffer<float> history;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
return "power-law";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
|
||||||
|
|
||||||
|
// these don't need to be modified or exposed to the user
|
||||||
|
const float peak_logit_value = 3.0f;
|
||||||
|
const float tail_heaviness = 3.0f;
|
||||||
|
|
||||||
|
const float min_target = ctx->target - ctx->target_range;
|
||||||
|
const float max_target = ctx->target + ctx->target_range;
|
||||||
|
|
||||||
|
// compute probabilities to get the "original" values
|
||||||
|
llama_sampler_softmax_impl(cur_p, false);
|
||||||
|
|
||||||
|
// store original probabilities (needed for history update)
|
||||||
|
std::vector<float> original_probs;
|
||||||
|
original_probs.reserve(cur_p->size);
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
original_probs.push_back(cur_p->data[i].p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate adaptive target
|
||||||
|
float computed_target = ctx->target;
|
||||||
|
if (ctx->history.size() > 0) {
|
||||||
|
float sum_excluding_oldest = 0.0f;
|
||||||
|
size_t sz = ctx->history.size();
|
||||||
|
|
||||||
|
// sum all except the oldest element
|
||||||
|
for (size_t i = 0; i < sz - 1; ++i) {
|
||||||
|
sum_excluding_oldest += ctx->history.rat(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
float next_value = (ctx->target * ctx->queue_size) - sum_excluding_oldest;
|
||||||
|
computed_target = std::max(min_target, std::min(next_value, max_target));
|
||||||
|
}
|
||||||
|
|
||||||
|
// find closest token (for degenerate width ~ 0 case)
|
||||||
|
float min_distance = FLT_MAX;
|
||||||
|
int closest_token_idx = -1;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
float distance = std::abs(cur_p->data[i].p - computed_target);
|
||||||
|
if (distance < min_distance) {
|
||||||
|
min_distance = distance;
|
||||||
|
closest_token_idx = (int) i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply power law transformation
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
float p = cur_p->data[i].p;
|
||||||
|
|
||||||
|
float distance = std::abs(p - computed_target);
|
||||||
|
float normalized_distance = distance / 0.2f;
|
||||||
|
cur_p->data[i].logit = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness));
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_softmax_impl(cur_p, false);
|
||||||
|
|
||||||
|
// sample from distribution
|
||||||
|
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||||
|
|
||||||
|
// set sampled token
|
||||||
|
cur_p->selected = idx;
|
||||||
|
|
||||||
|
// update history with ORIGINAL probability
|
||||||
|
ctx->history.push_back(original_probs[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
|
||||||
|
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
|
||||||
|
ctx->history = ring_buffer<float>(ctx->queue_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
|
||||||
|
const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
|
||||||
|
auto * result = llama_sampler_init_power_law(ctx->target, ctx->target_range, ctx->queue_size, ctx->seed);
|
||||||
|
auto * result_ctx = (llama_sampler_power_law *) result->ctx;
|
||||||
|
|
||||||
|
result_ctx->history = ctx->history;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_power_law_free(struct llama_sampler * smpl) {
|
||||||
|
delete (llama_sampler_power_law *) smpl->ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct llama_sampler_i llama_sampler_power_law_i = {
|
||||||
|
/* .name = */ llama_sampler_power_law_name,
|
||||||
|
/* .accept = */ nullptr,
|
||||||
|
/* .apply = */ llama_sampler_power_law_apply,
|
||||||
|
/* .reset = */ llama_sampler_power_law_reset,
|
||||||
|
/* .clone = */ llama_sampler_power_law_clone,
|
||||||
|
/* .free = */ llama_sampler_power_law_free,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_sampler * llama_sampler_init_power_law(
|
||||||
|
float target,
|
||||||
|
float target_range,
|
||||||
|
int32_t queue_size,
|
||||||
|
uint32_t seed
|
||||||
|
) {
|
||||||
|
auto seed_cur = get_rng_seed(seed);
|
||||||
|
return llama_sampler_init(
|
||||||
|
/* .iface = */ &llama_sampler_power_law_i,
|
||||||
|
/* .ctx = */ new llama_sampler_power_law {
|
||||||
|
/* .target = */ target,
|
||||||
|
/* .target_range = */ target_range,
|
||||||
|
/* .queue_size = */ queue_size,
|
||||||
|
/* .seed = */ seed_cur,
|
||||||
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
|
/* .history = */ ring_buffer<float>(queue_size),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// logit-bias
|
// logit-bias
|
||||||
|
|
||||||
struct llama_sampler_logit_bias {
|
struct llama_sampler_logit_bias {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue