Merge 660a3b275f into 18ddaea2ae
Commit: aa4dadf4b1
@@ -1596,6 +1596,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--adaptive-target"}, "N",
+        string_format("adaptive-p: select tokens near this probability (valid range 0.0 "
+                      "to 1.0; negative = disabled) (default: %.2f)\n"
+                      "[(more info)](https://github.com/ggml-org/llama.cpp/pull/17927)",
+                      (double)params.sampling.adaptive_target),
+        [](common_params & params, const std::string & value) {
+            params.sampling.adaptive_target = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--adaptive-decay"}, "N",
+        string_format("adaptive-p: decay rate for target adaptation over time. lower values "
+                      "are more reactive, higher values are more stable.\n"
+                      "(valid range 0.0 to 0.99) (default: %.2f)",
+                      (double)params.sampling.adaptive_decay),
+        [](common_params & params, const std::string & value) {
+            params.sampling.adaptive_decay = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--dynatemp-range"}, "N",
         string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
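Usage note (not part of the diff): the two new flags slot in beside the existing sampling options, so a hypothetical invocation would look like `llama-cli -m model.gguf --adaptive-target 0.30 --adaptive-decay 0.90` (model path illustrative). Like the neighboring option handlers, the lambdas parse with `std::stof`, so non-numeric values throw.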
@@ -117,6 +117,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_INFILL = 9,
     COMMON_SAMPLER_TYPE_PENALTIES = 10,
     COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
+    COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -164,32 +165,34 @@ enum common_params_sampling_config : uint64_t {
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
 
-    int32_t n_prev = 64; // number of previous tokens to remember
-    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t top_k = 40; // <= 0 to use vocab size
-    float top_p = 0.95f; // 1.0 = disabled
-    float min_p = 0.05f; // 0.0 = disabled
-    float xtc_probability = 0.00f; // 0.0 = disabled
-    float xtc_threshold = 0.10f; // > 0.5 disables XTC
-    float typ_p = 1.00f; // typical_p, 1.0 = disabled
-    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float dynatemp_range = 0.00f; // 0.0 = disabled
-    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float penalty_repeat = 1.00f; // 1.0 = disabled
-    float penalty_freq = 0.00f; // 0.0 = disabled
-    float penalty_present = 0.00f; // 0.0 = disabled
-    float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
-    float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
-    int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
-    int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
-    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float top_n_sigma = -1.00f;// -1.0 = disabled
-    float mirostat_tau = 5.00f; // target entropy
-    float mirostat_eta = 0.10f; // learning rate
+    int32_t n_prev = 64; // number of previous tokens to remember
+    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k = 40; // <= 0 to use vocab size
+    float top_p = 0.95f; // 1.0 = disabled
+    float min_p = 0.05f; // 0.0 = disabled
+    float xtc_probability = 0.00f; // 0.0 = disabled
+    float xtc_threshold = 0.10f; // > 0.5 disables XTC
+    float typ_p = 1.00f; // typical_p, 1.0 = disabled
+    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float dynatemp_range = 0.00f; // 0.0 = disabled
+    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float penalty_repeat = 1.00f; // 1.0 = disabled
+    float penalty_freq = 0.00f; // 0.0 = disabled
+    float penalty_present = 0.00f; // 0.0 = disabled
+    float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
+    float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
+    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float top_n_sigma = -1.00f; // -1.0 = disabled
+    float mirostat_tau = 5.00f; // target entropy
+    float mirostat_eta = 0.10f; // learning rate
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;
 
     uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
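A quick aside on the `history ≈ 1/(1-decay)` comment (my gloss, not part of the patch): an EMA with decay `d` weights past selections by `d^0, d^1, d^2, …`, and those weights sum to `1/(1-d)`. The standalone sketch below runs the same `total_weight` recurrence the sampler uses and shows it converging to that limit, i.e. `decay = 0.90` behaves roughly like an average over the last ~10 selected tokens.

```cpp
// Standalone sketch: the sampler's total_weight recurrence converges to 1/(1-decay).
#include <cstdio>

int main() {
    const float decay = 0.90f; // corresponds to --adaptive-decay
    float total_weight = 0.0f;
    for (int i = 0; i < 100; ++i) {
        total_weight = 1.0f + decay * total_weight; // same update as the sampler
    }
    std::printf("total_weight = %.2f, limit 1/(1-decay) = %.2f\n",
                total_weight, 1.0f / (1.0f - decay)); // both ~10.00
    return 0;
}
```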
@@ -150,11 +150,11 @@ std::string common_params_sampling::print() const {
         "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
         "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
         "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
-        "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+        "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
         penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
         dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
         top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
-        mirostat, mirostat_eta, mirostat_tau);
+        mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
 
     return std::string(result);
 }
@@ -236,6 +236,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
 
     if (params.mirostat == 0) {
+
+        bool use_adaptive_p = false; // see below
+
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
@@ -245,43 +248,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
 
-                        samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
-                    samplers.push_back(llama_sampler_init_top_k (params.top_k));
+                    samplers.push_back(llama_sampler_init_top_k(params.top_k));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_P:
-                    samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
                     samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                     break;
                 case COMMON_SAMPLER_TYPE_MIN_P:
-                    samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_XTC:
-                    samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                     break;
                 case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
                     break;
                 case COMMON_SAMPLER_TYPE_INFILL:
-                    samplers.push_back(llama_sampler_init_infill (vocab));
+                    samplers.push_back(llama_sampler_init_infill(vocab));
                     break;
                 case COMMON_SAMPLER_TYPE_PENALTIES:
-                    samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                     break;
+                case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
+                    // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
+                    // a single token, so we will add `dist` at the end of the chain by default,
+                    // unless the user specifically included `adaptive-p`. we set this flag here
+                    // so we know to add the sampler at the very end.
+                    use_adaptive_p = true;
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
 
-        samplers.push_back(llama_sampler_init_dist(params.seed));
+        if (use_adaptive_p) {
+            // only if user explicitly included adaptive-p sampler
+            samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
+        } else {
+            // default: sample from distribution
+            samplers.push_back(llama_sampler_init_dist(params.seed));
+        }
     } else if (params.mirostat == 1) {
         samplers.push_back(llama_sampler_init_temp(params.temp));
         samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
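For illustration, a minimal sketch of driving this path programmatically (assuming a loaded `llama_model * model`; the sampler list and values are illustrative, following the patch's recommendation of only mild truncation before adaptive-p):

```cpp
// Sketch: the programmatic equivalent of passing `--samplers penalties;min_p;adaptive_p`.
#include "sampling.h" // common/sampling.h from llama.cpp

struct common_sampler * make_sampler(const llama_model * model) {
    common_params_sampling sparams;
    sparams.adaptive_target = 0.30f;
    sparams.adaptive_decay  = 0.90f;
    sparams.samplers = common_sampler_types_from_names(
        {"penalties", "min_p", "adaptive_p"}, /*allow_alt_names=*/true);
    // because "adaptive_p" is present, common_sampler_init() sets use_adaptive_p
    // and ends the chain with adaptive-p instead of the default dist sampler
    return common_sampler_init(model, sparams);
}
```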
@@ -567,6 +581,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_XTC: return 'x';
         case COMMON_SAMPLER_TYPE_INFILL: return 'i';
         case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
+        case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a';
         default : return '?';
     }
 }
@@ -583,6 +598,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_XTC: return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL: return "infill";
         case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
+        case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p";
         default : return "";
     }
 }
@@ -599,6 +615,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "xtc", COMMON_SAMPLER_TYPE_XTC },
         { "infill", COMMON_SAMPLER_TYPE_INFILL },
         { "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
+        { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
     };
 
     // since samplers names are written multiple ways
@@ -614,6 +631,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "min-p", COMMON_SAMPLER_TYPE_MIN_P },
         { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
     };
 
     std::vector<common_sampler_type> samplers;
@@ -1313,6 +1313,33 @@ extern "C" {
             const char ** seq_breakers,
             size_t num_breakers);
 
+    /// adaptive-p: select tokens near a configurable target probability over time.
+    ///
+    /// the adaptive-p sampler transforms the token probability distribution to favor tokens
+    /// that fall near a user-configurable probability target.
+    ///
+    /// internally, the sampler maintains an exponential moving average of the *ORIGINAL*
+    /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an
+    /// adapted target probability at each sampling step, thus maintaining the desired target
+    /// probability over time.
+    ///
+    /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last
+    /// in the sampler chain (like mirostat, dist, greedy).
+    ///
+    /// only mild truncation before this sampler is recommended. we suggest applying min-p
+    /// before adaptive-p as the only other active sampler in the chain.
+    ///
+    /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
+    /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
+    /// @param seed RNG seed
+    ///
+    /// ref: https://github.com/ggml-org/llama.cpp/pull/17927
+    ///
+    LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p(
+                               float target,
+                               float decay,
+                            uint32_t seed);
+
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                              int32_t n_vocab,
                              int32_t n_logit_bias,
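Taking the header comment's advice literally, a minimal chain would apply min-p for mild truncation and then adaptive-p as the final, token-selecting stage. A sketch against the public API above (context/model setup omitted; parameter values illustrative):

```cpp
#include "llama.h"

// Build a sampler chain per the recommendation above: min-p first, adaptive-p last.
llama_sampler * make_adaptive_chain(void) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, /*min_keep=*/1));
    llama_sampler_chain_add(chain, llama_sampler_init_adaptive_p(
        /*target=*/0.30f, /*decay=*/0.90f, /*seed=*/LLAMA_DEFAULT_SEED));
    return chain; // sample with llama_sampler_sample(chain, ctx, -1)
}
```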
@@ -2340,6 +2340,138 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
     return result;
 }
 
+// adaptive-p sampler state
+//
+// maintains an exponential moving average of the *ORIGINAL* probabilities
+// of selected tokens, used to compute an adapted target at each sampling step.
+//
+// see llama.h for a full description of the sampler
+//
+// ref: https://github.com/ggml-org/llama.cpp/pull/17927
+//
+struct llama_sampler_adaptive_p {
+    const float target; // target probability (0.0 - 1.0; negative = disabled)
+    const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
+    const uint32_t seed; // RNG seed
+    std::mt19937 rng; // RNG
+    float weighted_sum; // sum(p_i * decay^i)
+    float total_weight; // sum(decay^i), converges to 1/(1-decay)
+    std::vector<float> original_probs; // pre-transform probs, cached for EMA update
+};
+
+// adaptive probability transformation constants
+static constexpr float DISTRIBUTION_WIDTH = 0.3f;
+static constexpr float PEAK_LOGIT_VALUE = 5.0f;
+static constexpr float SHARPNESS = 10.0f;
+static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;
+
+static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
+    return "adaptive-p";
+}
+
+static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
+
+    if (ctx->target < 0.0f) {
+        // at negative target values, adaptive-p is no-op
+        // we simply sample from the existing distribution
+        llama_sampler_softmax_impl(cur_p, false);
+        cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+        return;
+    }
+
+    // softmax and store the original probabilities
+    llama_sampler_softmax_impl(cur_p, false);
+    ctx->original_probs.resize(cur_p->size);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        ctx->original_probs[i] = cur_p->data[i].p;
+    }
+
+    // compute the adapted target probability for the current sampling step
+    auto target = std::clamp(ctx->target, 0.0f, 1.0f);
+    float adapted_target = std::clamp(
+        ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
+        0.0f, 1.0f
+    );
+
+    // adaptive probability transform
+    //
+    // quadratic near target for fine differentiation, transitioning to linear decay in the
+    // tails. unbounded negative logits ensure proper suppression of far-from-target tokens
+    // after the softmax.
+    //
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
+        cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
+    }
+
+    // softmax and sample from the transformed distribution
+    llama_sampler_softmax_impl(cur_p, false);
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
+    cur_p->selected = idx;
+
+    // update EMA with the original probability of the selected token
+    ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
+    ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
+}
+
+static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
+    // ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
+    // original_probs is completely overwritten on every call to _apply.
+    // so we only need to reset the EMA state.
+    ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
+    ctx->total_weight = 1.0f / (1.0f - ctx->decay);
+}
+
+static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx;
+    auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
+    auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
+
+    // copy everything (target, decay, and seed are already set)
+    result_ctx->original_probs = ctx->original_probs;
+    result_ctx->weighted_sum = ctx->weighted_sum;
+    result_ctx->total_weight = ctx->total_weight;
+    result_ctx->rng = ctx->rng;
+
+    return result;
+}
+
+static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_adaptive_p *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_adaptive_p_i = {
+    /* .name = */ llama_sampler_adaptive_p_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_adaptive_p_apply,
+    /* .reset = */ llama_sampler_adaptive_p_reset,
+    /* .clone = */ llama_sampler_adaptive_p_clone,
+    /* .free = */ llama_sampler_adaptive_p_free,
+};
+
+struct llama_sampler * llama_sampler_init_adaptive_p(
+        float target,
+        float decay,
+        uint32_t seed
+) {
+    auto seed_cur = get_rng_seed(seed);
+    float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_adaptive_p_i,
+        /* .ctx = */ new llama_sampler_adaptive_p {
+            /* .target = */ target,
+            /* .decay = */ clamped_decay,
+            /* .seed = */ seed_cur,
+            /* .rng = */ std::mt19937(seed_cur),
+            /* .weighted_sum = */ target / (1.0f - clamped_decay),
+            /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
+            /* .original_probs = */ {},
+        }
+    );
+}
+
 // logit-bias
 
 struct llama_sampler_logit_bias {
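Two details worth spelling out (my reading of the code above, not part of the patch). First, `adapted_target = 2*target - ema` deliberately overshoots: if recent selections averaged above the target, the next step aims symmetrically below it, steering the running average back toward `target`. Second, the transform's shape is easy to inspect in isolation:

```cpp
// Standalone restatement of the adaptive transform: quadratic near the adapted
// target (d^2/(1+d) ~ d^2 for small d), easing into linear decay in the tails (~d).
#include <cmath>
#include <cstdio>

static float adaptive_logit(float p, float adapted_target) {
    const float d = std::fabs(p - adapted_target) / 0.3f; // DISTRIBUTION_WIDTH
    return 5.0f - 10.0f * d * d / (1.0f + d);             // PEAK_LOGIT_VALUE - SHARPNESS*d^2/(1+d)
}

int main(void) {
    for (float p = 0.0f; p <= 1.0001f; p += 0.1f) {
        std::printf("p = %.1f -> logit = %+.2f\n", p, adaptive_logit(p, /*adapted_target=*/0.30f));
    }
    return 0;
}
```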
@@ -201,6 +201,8 @@ task_params server_task::params_from_json_cmpl(
     params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
     params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
     params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+    params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target);
+    params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay);
     params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
     params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
     params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
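Server-side, the two keys ride along in the completion request body like any other sampling field; a hypothetical request could include `"adaptive_target": 0.3, "adaptive_decay": 0.9`, with missing keys falling back to the server defaults via `json_value`.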