add `use_power_law` flag + logic, minor cleanup

ddh0 2025-12-17 15:06:05 -06:00
parent 27dda80dd7
commit 775299892e
3 changed files with 31 additions and 27 deletions

View File

@@ -241,8 +241,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
     if (params.mirostat == 0) {
         // if this flag is set, we will not need to add `dist` at the end of the sampler chain
-        bool has_distribution_sampler = false;
+        bool use_power_law = false;
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
@@ -253,46 +253,52 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
-                        samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
                     break;
                case COMMON_SAMPLER_TYPE_TOP_K:
-                    samplers.push_back(llama_sampler_init_top_k (params.top_k));
+                    samplers.push_back(llama_sampler_init_top_k(params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
-                    samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
-                    samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
-                    samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
-                    samplers.push_back(llama_sampler_init_infill (vocab));
+                    samplers.push_back(llama_sampler_init_infill(vocab));
                    break;
                case COMMON_SAMPLER_TYPE_PENALTIES:
-                    samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
                case COMMON_SAMPLER_TYPE_POWER_LAW:
-                    has_distribution_sampler = true;
-                    samplers.push_back(llama_sampler_init_power_law (params.power_law_target, params.power_law_decay, params.seed));
+                    // the `power_law` sampler is like `dist` in that it selects a single token,
+                    // so we will add `dist` at the end of the chain by default, unless the user
+                    // specifically included `power_law`. we set this flag here so we know to add
+                    // it at the very end.
+                    use_power_law = true;
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
            }
        }
-        // only add `dist` to the end of the chain if no other distribution samplers were added
-        if (!has_distribution_sampler) {
+        if (use_power_law) {
+            // only if user explicitly included power_law sampler
+            samplers.push_back(llama_sampler_init_power_law(params.power_law_target, params.power_law_decay, params.seed));
+        } else {
+            // default: sample from distribution
            samplers.push_back(llama_sampler_init_dist(params.seed));
        }
    } else if (params.mirostat == 1) {
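
The net effect of the new logic is that `power_law` is no longer pushed in the middle of the chain; it is appended exactly once, at the very end, in place of `dist`. A minimal sketch of how a caller might opt in, assuming the `common_params_sampling` fields referenced in this hunk (`samplers`, `power_law_target`, `power_law_decay`, `seed`) and an already-loaded `model`; the values are illustrative, not part of the patch:

    common_params_sampling sparams;
    sparams.samplers = {
        COMMON_SAMPLER_TYPE_TOP_K,
        COMMON_SAMPLER_TYPE_TEMPERATURE,
        COMMON_SAMPLER_TYPE_POWER_LAW,  // opt in: the chain will end with power_law instead of dist
    };
    sparams.power_law_target = 0.30f;   // illustrative value
    sparams.power_law_decay  = 0.90f;   // illustrative value
    // inside the loop above, POWER_LAW only sets the flag; the sampler itself is
    // appended after the loop, so it is always the final, token-selecting stage
    struct common_sampler * smpl = common_sampler_init(model, sparams);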

View File

@@ -1309,7 +1309,7 @@ extern "C" {
     /// this sampler implements a power law probability transformation with adaptive
     /// target tracking. it reshapes token probability distributions to favor tokens near a
     /// configurable target probability, rather than always selecting from the highest probability
-    /// candidates. it is ideal for creative, unpredictable text generation.
+    /// candidates.
     ///
     /// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
     /// rather than just transforming logits. therefore it must always be the last sampler in the
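
In other words, `power_law` occupies the slot that `dist`, `greedy`, or a `mirostat` sampler would otherwise fill. A minimal hand-built chain as a sketch, assuming the existing `llama_sampler_chain_*` API from llama.h and the `(target, decay, seed)` signature of `llama_sampler_init_power_law` used elsewhere in this commit (values are illustrative):

    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));                    // transforms logits only
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));                   // transforms logits only
    llama_sampler_chain_add(chain, llama_sampler_init_power_law(0.30f, 0.90f, 42));  // must come last: selects the token ID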

View File

@@ -2318,7 +2318,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 // this sampler implements a power law probability transformation with adaptive
 // target tracking. it reshapes token probability distributions to favor tokens near a
 // configurable target probability, rather than always selecting from the highest probability
-// candidates. it is ideal for creative, unpredictable text generation.
+// candidates.
 //
 // this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
 // rather than just transforming logits. therefore it must always be the last sampler in the
@@ -2332,7 +2332,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 struct llama_sampler_power_law {
     // the desired average probability for selected tokens (0.0 to 1.0)
-    // higher values favor more probable tokens (more deterministic)
+    // higher values favor more probable tokens (more stable and predictable)
     // lower values favor less probable tokens (more creative)
     // negative values disable Power Law sampling (sample from distribution as-is)
     const float target;
@@ -2341,19 +2341,17 @@ struct llama_sampler_power_law {
     // lower values = faster adaptation, more reactive to recent tokens
     // higher values = slower adaptation, more stable over time
     // effective history length ≈ 1/(1-decay) tokens
-    // examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
+    // example: decay=0.5 --> ~2 tokens; decay=0.9 --> ~10 tokens; decay=0.95 --> ~20 tokens
     // internally clamped to <= 0.99 to prevent unbounded accumulation
     const float decay;
     const uint32_t seed;
     std::mt19937 rng;
-    // historical token probabilities weighted by recency
-    float weighted_sum;
-    // sum of weights, converges to 1/(1-decay)
-    float total_weight;
-    // used to store original token probabilities (needed for history update after selection)
-    std::vector<float> original_probs;
+    // member variables
+    float weighted_sum; // historical token probabilities weighted by recency
+    float total_weight; // sum of weights, converges to 1/(1-decay)
+    std::vector<float> original_probs; // used to store original token probabilities
 };
 // transformation constants
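
Taken together, `weighted_sum` and `total_weight` form an exponentially decaying average of the selected tokens' original probabilities. The standalone sketch below (not part of the patch) simply iterates the same two recurrences that the apply function in the next hunk performs, showing that `total_weight` converges to 1/(1-decay), i.e. the ~10-token effective window quoted above for decay = 0.9:

    #include <cstdio>

    int main() {
        const float decay = 0.9f;  // effective history ≈ 1/(1-decay) = 10 tokens
        float weighted_sum = 0.0f;
        float total_weight = 0.0f;
        const float picked_p[] = {0.30f, 0.55f, 0.10f, 0.70f, 0.25f};  // made-up selected-token probabilities
        for (float p : picked_p) {
            weighted_sum = p    + decay * weighted_sum;  // same update as the sampler
            total_weight = 1.0f + decay * total_weight;  // converges toward 1/(1-decay)
        }
        std::printf("recency-weighted average = %.3f\n", weighted_sum / total_weight);
        std::printf("total_weight = %.3f, limit = %.3f\n", total_weight, 1.0f / (1.0f - decay));
        return 0;
    }
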
@@ -2401,8 +2399,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
     cur_p->selected = idx;
     // update running history with the original probability of the selected token
-    ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
-    ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time
+    ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum; // history fades over time
+    ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
 }
 static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
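
Unrolling these two recurrences makes the "fading" explicit: after the update, the token selected k steps ago contributes with weight decay^k, so weighted_sum / total_weight = (p_0 + decay*p_1 + decay^2*p_2 + ...) / (1 + decay + decay^2 + ...), where p_k is the original probability of the token selected k steps ago. Presumably this recency-weighted average is the quantity that the adaptive target tracking compares against `target`; the transformation itself lies outside the hunks shown here.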