diff --git a/common/sampling.cpp b/common/sampling.cpp index 05e44170e4..d571c5ecd4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -241,8 +241,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } if (params.mirostat == 0) { - // if this flag is set, we will not need to add `dist` at the end of the sampler chain - bool has_distribution_sampler = false; + + bool use_power_law = false; for (const auto & cnstr : params.samplers) { switch (cnstr) { @@ -253,46 +253,52 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } - samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: - samplers.push_back(llama_sampler_init_top_k (params.top_k)); + samplers.push_back(llama_sampler_init_top_k(params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep)); + samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); break; case COMMON_SAMPLER_TYPE_MIN_P: - samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep)); + samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + 
samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep)); + samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - samplers.push_back(llama_sampler_init_infill (vocab)); + samplers.push_back(llama_sampler_init_infill(vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); break; case COMMON_SAMPLER_TYPE_POWER_LAW: - has_distribution_sampler = true; - samplers.push_back(llama_sampler_init_power_law (params.power_law_target, params.power_law_decay, params.seed)); + // the `power_law` sampler is like `dist` in that it selects a single token, + // so we will add `dist` at the end of the chain by default, unless the user + // specifically included `power_law`. we set this flag here so we know to add + // it at the very end. 
+ use_power_law = true; break; default: GGML_ASSERT(false && "unknown sampler type"); } } - // only add `dist` to the end of the chain if no other distribution samplers were added - if (!has_distribution_sampler) { + if (use_power_law) { + // only if user explicitly included power_law sampler + samplers.push_back(llama_sampler_init_power_law(params.power_law_target, params.power_law_decay, params.seed)); + } else { + // default: sample from distribution samplers.push_back(llama_sampler_init_dist(params.seed)); } } else if (params.mirostat == 1) { diff --git a/include/llama.h b/include/llama.h index 3ec3f25c95..f903d34a56 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1309,7 +1309,7 @@ extern "C" { /// this sampler implements a power law probability transformation with adaptive /// target tracking. it reshapes token probability distributions to favor tokens near a /// configurable target probability, rather than always selecting from the highest probability - /// candidates. it is ideal for creative, unpredictable text generation. + /// candidates. /// /// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID /// rather than just transforming logits. therefore it must always be the last sampler in the diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 77ec141a56..59393275a5 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -2318,7 +2318,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa // this sampler implements a power law probability transformation with adaptive // target tracking. it reshapes token probability distributions to favor tokens near a // configurable target probability, rather than always selecting from the highest probability -// candidates. it is ideal for creative, unpredictable text generation. +// candidates. 
// // this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID // rather than just transforming logits. therefore it must always be the last sampler in the @@ -2332,7 +2332,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa struct llama_sampler_power_law { // the desired average probability for selected tokens (0.0 to 1.0) - // higher values favor more probable tokens (more deterministic) + // higher values favor more probable tokens (more stable and predictable) // lower values favor less probable tokens (more creative) // negative values disable Power Law sampling (sample from distribution as-is) const float target; @@ -2341,19 +2341,17 @@ struct llama_sampler_power_law { // lower values = faster adaptation, more reactive to recent tokens // higher values = slower adaptation, more stable over time // effective history length ≈ 1/(1-decay) tokens - // examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20 + // example: decay=0.5 --> ~2 tokens; decay=0.9 --> ~10 tokens; decay=0.95 --> ~20 tokens // internally clamped to <= 0.99 to prevent unbounded accumulation const float decay; const uint32_t seed; std::mt19937 rng; - // historical token probabilities weighted by recency - float weighted_sum; - // sum of weights, converges to 1/(1-decay) - float total_weight; - // used to store original token probabilities (needed for history update after selection) - std::vector<float> original_probs; + // member variables + float weighted_sum; // historical token probabilities weighted by recency + float total_weight; // sum of weights, converges to 1/(1-decay) + std::vector<float> original_probs; // used to store original token probabilities }; // transformation constants @@ -2401,8 +2399,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok cur_p->selected = idx; // update running history with the original probability of the selected token - ctx->weighted_sum = 
ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum; - ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time + ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum; // history fades over time + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; } static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {