This commit is contained in:
ddh0 2025-12-16 22:47:04 -05:00 committed by GitHub
commit 7fd10c8ea0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 232 additions and 30 deletions

View File

@ -1572,6 +1572,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg(
{"--power-law-target"}, "N",
string_format("power law sampler: select tokens near this probability (valid range 0.0 "
"to 1.0; <0 = disabled) (default: %.2f)\n"
"[(more info)]""(https://github.com/ggml-org/llama.cpp/pull/17927)",
(double)params.sampling.power_law_target),
[](common_params & params, const std::string & value) {
params.sampling.power_law_target = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--power-law-decay"}, "N",
string_format("decay rate for target adaptation over time. lower values -> faster but less stable adaptation.\n"
"(valid range 0.0 to 1.0; ≤0 = no adaptation) (default: %.2f)", (double)params.sampling.power_law_decay),
[](common_params & params, const std::string & value) {
params.sampling.power_law_decay = std::stof(value);
}
).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--dynatemp-range"}, "N", {"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),

View File

@ -117,6 +117,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
COMMON_SAMPLER_TYPE_POWER_LAW = 12,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@ -184,8 +185,10 @@ struct common_params_sampling {
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float power_law_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float power_law_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false; bool ignore_eos = false;

View File

@ -151,11 +151,11 @@ std::string common_params_sampling::print() const {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, power_law_target = %.3f, power_law_decay = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau); mirostat, mirostat_eta, mirostat_tau, power_law_target, power_law_decay);
return std::string(result); return std::string(result);
} }
@ -241,6 +241,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
} }
if (params.mirostat == 0) { if (params.mirostat == 0) {
// if this flag is set, we will not need to add `dist` at the end of the sampler chain
bool has_distribution_sampler = false;
for (const auto & cnstr : params.samplers) { for (const auto & cnstr : params.samplers) {
switch (cnstr) { switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: case COMMON_SAMPLER_TYPE_DRY:
@ -250,7 +253,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
for (const auto & str : params.dry_sequence_breakers) { for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str()); c_breakers.push_back(str.c_str());
} }
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
@ -281,12 +283,18 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_PENALTIES:
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break; break;
case COMMON_SAMPLER_TYPE_POWER_LAW:
has_distribution_sampler = true;
samplers.push_back(llama_sampler_init_power_law (params.power_law_target, params.power_law_decay, params.seed));
break;
default: default:
GGML_ASSERT(false && "unknown sampler type"); GGML_ASSERT(false && "unknown sampler type");
} }
} }
// only add `dist` to the end of the chain if no other distribution samplers were added
if (!has_distribution_sampler) {
samplers.push_back(llama_sampler_init_dist(params.seed)); samplers.push_back(llama_sampler_init_dist(params.seed));
}
} else if (params.mirostat == 1) { } else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@ -553,6 +561,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
case COMMON_SAMPLER_TYPE_POWER_LAW: return 'w';
default : return '?'; default : return '?';
} }
} }
@ -569,6 +578,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
case COMMON_SAMPLER_TYPE_POWER_LAW: return "power_law";
default : return ""; default : return "";
} }
} }
@ -585,6 +595,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "xtc", COMMON_SAMPLER_TYPE_XTC }, { "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL }, { "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "power_law", COMMON_SAMPLER_TYPE_POWER_LAW },
}; };
// since samplers names are written multiple ways // since samplers names are written multiple ways
@ -600,6 +611,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "power-law", COMMON_SAMPLER_TYPE_POWER_LAW },
}; };
std::vector<common_sampler_type> samplers; std::vector<common_sampler_type> samplers;

View File

@ -1304,6 +1304,29 @@ extern "C" {
const char ** seq_breakers, const char ** seq_breakers,
size_t num_breakers); size_t num_breakers);
/// power-law
///
/// this sampler implements a power law probability transformation with adaptive
/// target tracking. it reshapes token probability distributions to favor tokens near a
/// configurable target probability, rather than always selecting from the highest probability
/// candidates. it is ideal for creative, unpredictable text generation.
///
/// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
/// rather than just transforming logits. therefore it must always be the last sampler in the
/// sampler chain.
///
/// minimal truncation before this sampler is recommended.
///
/// @param target select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
///
/// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
float target,
float decay,
uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab, int32_t n_vocab,
int32_t n_logit_bias, int32_t n_logit_bias,

View File

@ -2313,6 +2313,150 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
return result; return result;
} }
// power-law
//
// this sampler implements a power law probability transformation with adaptive
// target tracking. it reshapes token probability distributions to favor tokens near a
// configurable target probability, rather than always selecting from the highest probability
// candidates. it is ideal for creative, unpredictable text generation.
//
// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
// rather than just transforming logits. therefore it must always be the last sampler in the
// sampler chain.
//
// minimal truncation before this sampler is recommended.
//
// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
struct llama_sampler_power_law {
// the desired average probability for selected tokens (0.0 to 1.0)
// higher values favor more probable tokens (more deterministic)
// lower values favor less probable tokens (more creative)
// negative values disable Power Law sampling (sample from distribution as-is)
const float target;
// controls how quickly history influence fades (0.0 to 0.99)
// lower values = faster adaptation, more reactive to recent tokens
// higher values = slower adaptation, more stable over time
// effective history length ≈ 1/(1-decay) tokens
// examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
// internally clamped to <= 0.99 to prevent unbounded accumulation
const float decay;
const uint32_t seed;
std::mt19937 rng;
// historical token probabilities weighted by recency
float weighted_sum;
// sum of weights, converges to 1/(1-decay)
float total_weight;
// used to store original token probabilities (needed for history update after selection)
std::vector<float> original_probs;
};
// transformation constants
static constexpr float DISTRIBUTION_WIDTH = 0.3f;
static constexpr float PEAK_LOGIT_VALUE = 5.0f;
static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;
static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
return "power-law";
}
static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
if (ctx->target < 0.0f) {
// no-op: just sample from the distribution as-is
llama_sampler_softmax_impl(cur_p, false);
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
return;
}
// softmax and store the original probabilities
llama_sampler_softmax_impl(cur_p, false);
ctx->original_probs.resize(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
ctx->original_probs[i] = cur_p->data[i].p;
}
// compute the adapted target probability for the current sampling step
float computed_target = std::clamp(
ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
0.0f, 1.0f
);
// power law transform
for (size_t i = 0; i < cur_p->size; ++i) {
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
}
llama_sampler_softmax_impl(cur_p, false);
// sample from transformed distribution
const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx;
// update running history with the original probability of the selected token
ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time
}
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
ctx->weighted_sum = 0.0f;
ctx->total_weight = 0.0f;
}
static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
auto * result = llama_sampler_init_power_law(ctx->target, ctx->decay, ctx->seed);
auto * result_ctx = (llama_sampler_power_law *) result->ctx;
result_ctx->rng = ctx->rng;
result_ctx->weighted_sum = ctx->weighted_sum;
result_ctx->total_weight = ctx->total_weight;
result_ctx->original_probs.reserve(ctx->original_probs.capacity());
return result;
}
static void llama_sampler_power_law_free(struct llama_sampler * smpl) {
delete (llama_sampler_power_law *) smpl->ctx;
}
static struct llama_sampler_i llama_sampler_power_law_i = {
/* .name = */ llama_sampler_power_law_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_power_law_apply,
/* .reset = */ llama_sampler_power_law_reset,
/* .clone = */ llama_sampler_power_law_clone,
/* .free = */ llama_sampler_power_law_free,
};
struct llama_sampler * llama_sampler_init_power_law(
float target,
float decay,
uint32_t seed
) {
auto seed_cur = get_rng_seed(seed);
return llama_sampler_init(
/* .iface = */ &llama_sampler_power_law_i,
/* .ctx = */ new llama_sampler_power_law {
/* .target = */ std::clamp(target, 0.0f, 1.0f),
/* .decay = */ std::clamp(decay, 0.0f, 0.99f),
/* .seed = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .weighted_sum = */ 0.0f,
/* .total_weight = */ 0.0f,
/* .original_probs = */ {},
}
);
}
// logit-bias // logit-bias
struct llama_sampler_logit_bias { struct llama_sampler_logit_bias {

View File

@ -203,6 +203,8 @@ task_params server_task::params_from_json_cmpl(
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target);
params.sampling.power_law_decay = json_value(data, "power_law_decay", defaults.sampling.power_law_decay);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);