add `use_power_law` flag + logic, minor cleanup
This commit is contained in:
parent
27dda80dd7
commit
775299892e
|
|
@ -241,8 +241,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
}
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
// if this flag is set, we will not need to add `dist` at the end of the sampler chain
|
||||
bool has_distribution_sampler = false;
|
||||
|
||||
bool use_power_law = false;
|
||||
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
|
|
@ -253,46 +253,52 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
||||
samplers.push_back(llama_sampler_init_top_k(params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
samplers.push_back(llama_sampler_init_infill (vocab));
|
||||
samplers.push_back(llama_sampler_init_infill(vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_POWER_LAW:
|
||||
has_distribution_sampler = true;
|
||||
samplers.push_back(llama_sampler_init_power_law (params.power_law_target, params.power_law_decay, params.seed));
|
||||
// the `power_law` sampler is like `dist` in that it selects a single token,
|
||||
// so we will add `dist` at the end of the chain by default, unless the user
|
||||
// specifically included `power_law`. we set this flag here so we know to add
|
||||
// `power_law` (instead of `dist`) at the very end of the chain.
|
||||
use_power_law = true;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
// only add `dist` to the end of the chain if no other distribution samplers were added
|
||||
if (!has_distribution_sampler) {
|
||||
if (use_power_law) {
|
||||
// only if user explicitly included power_law sampler
|
||||
samplers.push_back(llama_sampler_init_power_law(params.power_law_target, params.power_law_decay, params.seed));
|
||||
} else {
|
||||
// default: sample from distribution
|
||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||
}
|
||||
} else if (params.mirostat == 1) {
|
||||
|
|
|
|||
|
|
@ -1309,7 +1309,7 @@ extern "C" {
|
|||
/// this sampler implements a power law probability transformation with adaptive
|
||||
/// target tracking. it reshapes token probability distributions to favor tokens near a
|
||||
/// configurable target probability, rather than always selecting from the highest probability
|
||||
/// candidates. it is ideal for creative, unpredictable text generation.
|
||||
/// candidates.
|
||||
///
|
||||
/// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
|
||||
/// rather than just transforming logits. therefore it must always be the last sampler in the
|
||||
|
|
|
|||
|
|
@ -2318,7 +2318,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
|
|||
// this sampler implements a power law probability transformation with adaptive
|
||||
// target tracking. it reshapes token probability distributions to favor tokens near a
|
||||
// configurable target probability, rather than always selecting from the highest probability
|
||||
// candidates. it is ideal for creative, unpredictable text generation.
|
||||
// candidates.
|
||||
//
|
||||
// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
|
||||
// rather than just transforming logits. therefore it must always be the last sampler in the
|
||||
|
|
@ -2332,7 +2332,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
|
|||
struct llama_sampler_power_law {
|
||||
|
||||
// the desired average probability for selected tokens (0.0 to 1.0)
|
||||
// higher values favor more probable tokens (more deterministic)
|
||||
// higher values favor more probable tokens (more stable and predictable)
|
||||
// lower values favor less probable tokens (more creative)
|
||||
// negative values disable Power Law sampling (sample from distribution as-is)
|
||||
const float target;
|
||||
|
|
@ -2341,19 +2341,17 @@ struct llama_sampler_power_law {
|
|||
// lower values = faster adaptation, more reactive to recent tokens
|
||||
// higher values = slower adaptation, more stable over time
|
||||
// effective history length ≈ 1/(1-decay) tokens
|
||||
// examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
|
||||
// example: decay=0.5 --> ~2 tokens; decay=0.9 --> ~10 tokens; decay=0.95 --> ~20 tokens
|
||||
// internally clamped to <= 0.99 to prevent unbounded accumulation
|
||||
const float decay;
|
||||
|
||||
const uint32_t seed;
|
||||
std::mt19937 rng;
|
||||
|
||||
// historical token probabilities weighted by recency
|
||||
float weighted_sum;
|
||||
// sum of weights, converges to 1/(1-decay)
|
||||
float total_weight;
|
||||
// used to store original token probabilities (needed for history update after selection)
|
||||
std::vector<float> original_probs;
|
||||
// member variables
|
||||
float weighted_sum; // historical token probabilities weighted by recency
|
||||
float total_weight; // sum of weights, converges to 1/(1-decay)
|
||||
std::vector<float> original_probs; // used to store original token probabilities
|
||||
};
|
||||
|
||||
// transformation constants
|
||||
|
|
@ -2401,8 +2399,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
cur_p->selected = idx;
|
||||
|
||||
// update running history with the original probability of the selected token
|
||||
ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
|
||||
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time
|
||||
ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum; // history fades over time
|
||||
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
|
||||
}
|
||||
|
||||
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue