diff --git a/common/arg.cpp b/common/arg.cpp index ceb4d74111..163c9b71b0 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1729,6 +1729,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_sparam()); + add_opt(common_arg( + {"--adaptive-target"}, "N", + string_format("adaptive-p: select tokens near this probability (valid range 0.0 " + "to 1.0; negative = disabled) (default: %.2f)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/17927)", + (double)params.sampling.adaptive_target), + [](common_params & params, const std::string & value) { + params.sampling.adaptive_target = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--adaptive-decay"}, "N", + string_format("adaptive-p: decay rate for target adaptation over time. lower values " + "are more reactive, higher values are more stable.\n" + "(valid range 0.0 to 0.99) (default: %.2f)", + (double)params.sampling.adaptive_decay), + [](common_params & params, const std::string & value) { + params.sampling.adaptive_decay = std::stof(value); + } + ).set_sparam()); add_opt(common_arg( {"--dynatemp-range"}, "N", string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), diff --git a/common/common.h b/common/common.h index e60087dea3..b9566df62c 100644 --- a/common/common.h +++ b/common/common.h @@ -119,6 +119,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, + COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12, }; // dimensionality reduction methods, used by cvector-generator @@ -166,32 +167,34 @@ enum common_params_sampling_config : uint64_t { struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float xtc_probability = 0.00f; // 0.0 = disabled - float xtc_threshold = 0.10f; // > 0.5 disables XTC - float typ_p = 1.00f; // typical_p, 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: - float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) - int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty - int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float top_n_sigma = -1.00f;// -1.0 = disabled - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f; // -1.0 = disabled + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; - bool no_perf = false; // disable performance metrics + bool no_perf = false; // disable performance metrics bool timing_per_token = false; uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers diff --git a/common/sampling.cpp b/common/sampling.cpp index 54f6377a1e..11a1d48398 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -167,11 +167,11 @@ std::string common_params_sampling::print() const { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, - mirostat, mirostat_eta, mirostat_tau); + mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay); return std::string(result); } @@ -255,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } if (params.mirostat == 0) { + + bool use_adaptive_p = false; // see below + for (const auto & cnstr : params.samplers) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: @@ -264,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } - - samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: - samplers.push_back(llama_sampler_init_top_k (params.top_k)); + samplers.push_back(llama_sampler_init_top_k(params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep)); + samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); break; case COMMON_SAMPLER_TYPE_MIN_P: - samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep)); + samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep)); + samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - samplers.push_back(llama_sampler_init_infill (vocab)); + samplers.push_back(llama_sampler_init_infill(vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: + // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects + // a single token, so we will add `dist` at the end of the chain by default, + // unless the user specifically included `adaptive-p`. we set this flag here + // so we know to add the sampler at the very end. + use_adaptive_p = true; break; default: GGML_ASSERT(false && "unknown sampler type"); } } - - samplers.push_back(llama_sampler_init_dist(params.seed)); + if (use_adaptive_p) { + // only if user explicitly included adaptive-p sampler + samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed)); + } else { + // default: sample from distribution + samplers.push_back(llama_sampler_init_dist(params.seed)); + } } else if (params.mirostat == 1) { samplers.push_back(llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); @@ -625,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a'; default : return '?'; } } @@ -641,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p"; default : return ""; } } @@ -657,6 +673,7 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; // since samplers names are written multiple ways @@ -672,6 +689,7 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; std::vector samplers; @@ -708,6 +726,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; std::vector samplers; diff --git a/include/llama.h b/include/llama.h index bff788f92a..280745713e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1395,6 +1395,33 @@ extern "C" { const char ** seq_breakers, size_t num_breakers); + /// adaptive-p: select tokens near a configurable target probability over time. + /// + /// the adaptive-p sampler transforms the token probability distribution to favor tokens + /// that fall near a user-configurable probability target. + /// + /// internally, the sampler maintains an exponential moving average of the *ORIGINAL* + /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an + /// adapted target probability at each sampling step, thus maintaining the desired target + /// probability over time. + /// + /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last + /// in the sampler chain (like mirostat, dist, greedy). + /// + /// only mild truncation before this sampler is recommended. we suggest applying min-p + /// before adaptive-p as the only other active sampler in the chain. + /// + /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) + /// @param seed RNG seed + /// + /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 + /// + LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 11f0394c4c..5dde513065 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1513,12 +1513,9 @@ static void llama_sampler_top_p_backend_apply( mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // top_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * top_p_bias = ggml_log(ctx, mask); ggml_set_name(top_p_bias, "top_p_bias"); data->logits = ggml_add(ctx, sorted_logits, top_p_bias); @@ -1673,15 +1670,11 @@ static void llama_sampler_min_p_backend_apply( struct ggml_tensor * mask = ggml_step(ctx, sub); ggml_set_name(mask, "min_p_mask"); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // min_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * min_p_bias = ggml_log(ctx, mask); ggml_set_name(min_p_bias, "min_p_bias"); - // Add the min_p bias to the logits. data->logits = ggml_add(ctx, data->logits, min_p_bias); ggml_set_name(data->logits, "min_p_logits"); @@ -3293,6 +3286,170 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa return result; } +// adaptive-p sampler state +// +// maintains an exponential moving average of the *ORIGINAL* probabilities +// of selected tokens, used to compute an adapted target at each sampling step. +// +// see llama.h for a full description of the sampler +// +// ref: https://github.com/ggml-org/llama.cpp/pull/17927 +// +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99) + const uint32_t seed; // original RNG seed + uint32_t seed_cur; // actual RNG seed + std::mt19937 rng; // RNG state + float weighted_sum; // sum(p_i * decay^i) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + std::vector original_probs; // pre-transform probs, cached for EMA update + llama_token pending_token_id; // token ID of selected token + int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs +}; + +// adaptive probability transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float SHARPNESS = 10.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + +static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) { + return "adaptive-p"; +} + +static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, false); + + if (ctx->target < 0.0f) { + // at negative target values, adaptive-p is no-op + // we simply sample from the existing distribution + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); + return; + } + + // store the original probabilities + ctx->original_probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->original_probs[i] = cur_p->data[i].p; + } + + // using the EMA, compute the adapted target probability for the current sampling step + auto target = std::clamp(ctx->target, 0.0f, 1.0f); + float adapted_target = std::clamp( + ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + + // adaptive probability transform + // + // quadratic near target for fine differentiation, transitioning to linear decay in the + // tails. unbounded negative logits ensure proper suppression of far-from-target tokens + // after the softmax. + // + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit == -INFINITY) { + // don't transform logits that are -INFINITY + // (as masked out by e.g. min-p and top-p when using backend sampling) + continue; + } + float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); + cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); + } + + // softmax and sample from the transformed distribution + llama_sampler_softmax_impl(cur_p, false); + const int idx = llama_sample_dist(cur_p, ctx->rng); + cur_p->selected = idx; + + // store the selected token ID for acceptance later + ctx->pending_token_id = cur_p->data[idx].id; + ctx->pending_token_idx = idx; +} + +static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + if (ctx->pending_token_id == token) { + GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(ctx->pending_token_idx != -1); + // update EMA with the original probability of the selected token + ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; + } + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; +} + +static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. + // original_probs is completely overwritten on every call to _apply. + // so we only need to reset the EMA state and pending token. + ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); + ctx->total_weight = 1.0f / (1.0f - ctx->decay); + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; + auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); + auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; + + // copy everything (target, decay, seed, and RNG are already set) + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; + result_ctx->pending_token_id = ctx->pending_token_id; + result_ctx->pending_token_idx = ctx->pending_token_idx; + + return result; +} + +static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_adaptive_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_adaptive_p_i = { + /* .name = */ llama_sampler_adaptive_p_name, + /* .accept = */ llama_sampler_adaptive_p_accept, + /* .apply = */ llama_sampler_adaptive_p_apply, + /* .reset = */ llama_sampler_adaptive_p_reset, + /* .clone = */ llama_sampler_adaptive_p_clone, + /* .free = */ llama_sampler_adaptive_p_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed +) { + auto seed_cur = get_rng_seed(seed); + float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return llama_sampler_init( + /* .iface = */ &llama_sampler_adaptive_p_i, + /* .ctx = */ new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* .total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .original_probs = */ {}, + /* .pending_token_id = */ LLAMA_TOKEN_NULL, + /* .pending_token_idx = */ -1 + } + ); +} + // logit-bias struct llama_sampler_logit_bias : public llama_sampler_backend { diff --git a/tools/cli/README.md b/tools/cli/README.md index 9e9f9bead1..3b6f0708ed 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -113,6 +113,8 @@ | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | | `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | +| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) | +| `--adaptive-decay N` | adaptive-p: EMA decay for adaptation; effective history length ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) | | `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) | | `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | | `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | diff --git a/tools/completion/README.md b/tools/completion/README.md index 9eea90eed0..a16be3f684 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -436,6 +436,19 @@ The Min-P sampling method was designed as an alternative to Top-P, and aims to e Example usage: `--min-p 0.05` +### Adaptive-P Sampling + +- `--adaptive-target N`: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) +- `--adaptive-decay N`: EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) + +Adaptive-P: Select tokens near a configurable target probability over time. + +The adaptive-p sampler transforms the token probability distribution to favor tokens that fall near a user-configurable probability target. Internally, the sampler maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens at each sampling step. It uses this EMA to compute an adapted target probability at each sampling step, thus maintaining the desired target probability over time. Only mild truncation before this sampler is recommended. It is suggested to apply min-p before adaptive-p as the only other active sampler. + +Recommended starting values: `--adaptive-target 0.55 --adaptive-decay 0.9` + +For more info, refer to: [llama.cpp#17927](https://github.com/ggml-org/llama.cpp/pull/17927) + ### Locally Typical Sampling - `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled). diff --git a/tools/server/README.md b/tools/server/README.md index b1a0583d1c..9fe8938768 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -130,6 +130,8 @@ For the ful list of features, please refer to [server's changelog](https://githu | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | | `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | +| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) | +| `--adaptive-decay N` | adaptive-p: EMA decay for adaptation; effective history length ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) | | `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) | | `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | | `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index aa4590e4ec..35ec7ad2ad 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -204,6 +204,8 @@ task_params server_task::params_from_json_cmpl( params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target); + params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay); params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);