re-write + change parameters + simplify

parent 67a733670e
commit a96ddd743a
@@ -164,35 +164,35 @@ enum common_params_sampling_config : uint64_t {
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

     int32_t n_prev = 64; // number of previous tokens to remember
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t top_k = 40; // <= 0 to use vocab size
     float top_p = 0.95f; // 1.0 = disabled
     float min_p = 0.05f; // 0.0 = disabled
     float xtc_probability = 0.00f; // 0.0 = disabled
     float xtc_threshold = 0.10f; // > 0.5 disables XTC
     float typ_p = 1.00f; // typical_p, 1.0 = disabled
     float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float dynatemp_range = 0.00f; // 0.0 = disabled
     float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
     int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float penalty_repeat = 1.00f; // 1.0 = disabled
     float penalty_freq = 0.00f; // 0.0 = disabled
     float penalty_present = 0.00f; // 0.0 = disabled
     float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
     float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
     int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
     int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
-    float power_law_target = -1.0f; // target probability for Power Law sampling (valid range 0.0 to 1.0; <0 = disabled)
-    int32_t power_law_window_size = 10; // rolling window size for target adaptation in Power Law sampling (≤0 = fixed target)
+    float power_law_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
+    float power_law_decay = 0.9f; // decay rate for target adaptation over time; lower values -> faster but less stable adaptation (valid range 0.0 to 1.0; ≤0 = no adaptation)
     int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float top_n_sigma = -1.00f; // -1.0 = disabled
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;

     uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
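With this change, enabling the sampler from code means setting `power_law_target` to a value in [0, 1] (the default -1.0f leaves it disabled) and tuning `power_law_decay`. A minimal sketch; the values are illustrative, not defaults from this commit:

    common_params_sampling sparams;
    sparams.power_law_target = 0.35f; // >= 0 enables power-law: aim for ~0.35 average selected probability
    sparams.power_law_decay  = 0.9f;  // effective history of about 1/(1 - 0.9) = 10 tokens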
@@ -1289,24 +1289,28 @@ extern "C" {
             const char ** seq_breakers,
             size_t num_breakers);

-    /// @details power-law sampler - reshapes probability distribution to target specific probability ranges
+    /// power-law
+    ///
+    /// this sampler implements a power law probability transformation with adaptive
+    /// target tracking. it reshapes token probability distributions to favor tokens near a
+    /// configurable target probability, rather than always selecting from the highest probability
+    /// candidates. it is ideal for creative, unpredictable text generation.
     ///
     /// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
     /// rather than just transforming logits. therefore it must always be the last sampler in the
     /// sampler chain.
     ///
-    /// it is recommended to only perform minimal truncation before this sampler.
+    /// minimal truncation before this sampler is recommended.
     ///
-    /// @param target target probability (valid range 0.0 to 1.0; <0 = disabled)
-    /// @param window_size rolling window size for target adaptation (≤0 = fixed target)
-    /// @param seed RNG seed
+    /// @param target select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
+    /// @param decay decay rate for target adaptation over time; lower values -> faster but less stable adaptation (valid range 0.0 to 1.0; ≤0 = no adaptation)
    ///
-    /// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation)
+    /// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
     /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
     LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
             float target,
-            int32_t window_size,
+            float decay,
             uint32_t seed);

     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
             int32_t n_vocab,
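Since the power-law sampler selects a token itself, it has to terminate the chain. A minimal usage sketch against the public sampler-chain API; the min-p truncation step, parameter values, and `lctx` (an existing `llama_context *`) are illustrative assumptions:

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.02f, 1));            // minimal truncation only
    llama_sampler_chain_add(chain, llama_sampler_init_power_law(0.35f, 0.9f, 42)); // must be last: selects the token

    llama_token tok = llama_sampler_sample(chain, lctx, -1);
    llama_sampler_free(chain);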
@@ -2315,133 +2315,62 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa

 // power-law
 //
+// this sampler implements a power law probability transformation with adaptive
+// target tracking. it reshapes token probability distributions to favor tokens near a
+// configurable target probability, rather than always selecting from the highest probability
+// candidates. it is ideal for creative, unpredictable text generation.
+//
 // this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
 // rather than just transforming logits. therefore it must always be the last sampler in the
 // sampler chain.
 //
-// it is recommended to only perform minimal truncation before this sampler.
+// minimal truncation before this sampler is recommended.
 //
-// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation)
+// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
 // ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)

 struct llama_sampler_power_law {
-    const float target;
-    const int32_t window_size;
-
+    // the desired average probability for selected tokens (0.0 to 1.0)
+    // higher values favor more probable tokens (more deterministic)
+    // lower values favor less probable tokens (more creative)
+    // negative values disable Power Law sampling (sample from distribution as-is)
+    const float target;
+
+    // controls how quickly history influence fades (0.0 to 0.99)
+    // lower values = faster adaptation, more reactive to recent tokens
+    // higher values = slower adaptation, more stable over time
+    // effective history length ≈ 1/(1-decay) tokens
+    // examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
+    // internally clamped to <= 0.99 to prevent unbounded accumulation
+    const float decay;
+
     const uint32_t seed;
     std::mt19937 rng;
-    ring_buffer<float> window;
+
+    // historical token probabilities weighted by recency
+    float weighted_sum;
+    // sum of weights, converges to 1/(1-decay)
+    float total_weight;
 };

 static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
     return "power-law";
 }

-// Computes the target probability for the current sampling step.
-//
-// The target determines which token probabilities the power law distribution
-// will favor. This function implements a dynamic feedback mechanism to maintain
-// an average selection probability close to the base target over time.
-//
-// When the window is empty:
-// - Returns the base target value (ctx->target)
-//
-// When the window has entries:
-// - Calculates what the next target should be to keep the weighted average
-//   of selected token probabilities equal to ctx->target
-// - Uses exponential decay weighting: newer values have more influence
-//
-// Exponential Decay Weighting:
-// After inserting the new value, the weights will be:
-//   new_value:  weight = 1 (age 0, newest)
-//   rat(0):     weight = decay (age 1)
-//   rat(1):     weight = decay^2 (age 2)
-//   ...
-//   rat(sz-2):  weight = decay^(sz-1)
-//   rat(sz-1):  evicted (oldest)
-//
-// The "effective window size" is approximately 1/(1-decay):
-//   decay=0.9  → effective window ≈ 10 tokens
-//   decay=0.95 → effective window ≈ 20 tokens
-//   decay=1.0  → no decay, equivalent to simple average (original behavior)
-//
-// Formula derivation:
-// We want the weighted average after insertion to equal target:
-//
-//   (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target
-//
-// Where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1)
-//                    = (1 - decay^sz) / (1 - decay)   [geometric series]
-//
-// Solving for new_value:
-//   new_value = target * total_weight - decay * Σ rat(i) * decay^i
-//
-// The factor of 'decay' on the sum accounts for all existing values
-// shifting one position older when the new value is inserted.
-//
-// The exponential decay helps prevent "fishtailing" - a phenomenon where
-// forced high-probability selections (when the model is very confident)
-// cause the algorithm to overcorrect with many low-probability selections,
-// then swing back the other way. By decaying old values, the influence of
-// forced selections fades faster, reducing oscillation amplitude and
-// recovery time.
-//
-// Finally, the computed target is clamped to [min_target, max_target] to
-// prevent extreme values that could destabilize sampling.
-//
-static float llama_sampler_power_law_compute_target(
-        const llama_sampler_power_law * ctx,
-        float min_target,
-        float max_target,
-        float tail_decay) {
-
-    float computed_target = ctx->target;
-    size_t sz = ctx->window.size();
-
-    if (sz > 0) {
-        // Check if window is at capacity (oldest element will be evicted on next push)
-        // Use the window_size parameter from context, not a capacity() method
-        const bool window_full = (sz == (size_t)ctx->window_size);
-
-        // Compute weighted sum with exponential decay
-        // rat(0) = newest in buffer, gets weight 1
-        // rat(i) gets weight decay^i
-        //
-        // When window is full: exclude oldest element (it will be evicted)
-        // When window is not full: include all elements (nothing evicted)
-        float weighted_sum = 0.0f;
-        float weight = 1.0f;
-        size_t elements_to_sum = window_full ? (sz - 1) : sz;
-
-        for (size_t i = 0; i < elements_to_sum; ++i) {
-            weighted_sum += ctx->window.rat(i) * weight;
-            weight *= tail_decay;
-        }
-
-        // Shift weights to account for new value taking position 0
-        // All existing values age by 1, so multiply their weights by decay
-        float shifted_weighted_sum = weighted_sum * tail_decay;
-
-        // Compute total weight after new value is inserted
-        // When full: sz elements remain (oldest evicted, new added)
-        // When not full: sz + 1 elements (new added, nothing evicted)
-        size_t final_element_count = window_full ? sz : (sz + 1);
-
-        float total_weight;
-        if (std::abs(tail_decay - 1.0f) < FLT_EPSILON) {
-            total_weight = (float) final_element_count;
-        } else {
-            total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
-        }
-
-        // Solve for the new value that achieves target weighted average
-        float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
-
-        // Clamp to allowed range
-        computed_target = std::max(min_target, std::min(next_value, max_target));
-    }
-
-    return computed_target;
-}
+// compute the adaptive target probability for the current sampling step
+static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, float decay) {
+    if (ctx->total_weight == 0.0f) {
+        // if there is no history, just use base target
+        return ctx->target;
+    }
+
+    // maintain a running weighted sum with exponential decay
+    float new_total_weight = 1.0f + decay * ctx->total_weight;
+    float next_value = ctx->target * new_total_weight - decay * ctx->weighted_sum;
+
+    // clamp to [0.0, 1.0]
+    return std::max(0.0f, std::min(next_value, 1.0f));
+}
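The simplification works because the ring buffer is unnecessary: the decayed window sums can be carried as two scalars, with S = p + decay·S and W = 1 + decay·W after each selection, and the next target solving (x + decay·S) / (1 + decay·W) = target. To see the feedback behavior, here is a small standalone sketch (not part of this commit) that mirrors the recurrence; note how the computed target dips after a high-probability selection to pull the running average back toward the base target:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const float target = 0.5f;
        const float decay  = 0.9f;

        float weighted_sum = 0.0f; // S: decayed sum of selected-token probabilities
        float total_weight = 0.0f; // W: decayed sum of weights, converges to 1/(1-decay)

        // pretend the selected token's original probability alternates high/low
        const float selected_p[] = { 0.9f, 0.2f, 0.8f, 0.3f, 0.6f };

        for (float p : selected_p) {
            float next_target;
            if (total_weight == 0.0f) {
                next_target = target; // no history yet
            } else {
                // solve (x + decay*S) / (1 + decay*W) = target for x, then clamp to [0, 1]
                float new_total_weight = 1.0f + decay * total_weight;
                next_target = std::max(0.0f, std::min(target * new_total_weight - decay * weighted_sum, 1.0f));
            }
            printf("computed target = %.3f, selected p = %.2f\n", next_target, p);

            // history update, mirroring llama_sampler_power_law_apply
            weighted_sum = p + decay * weighted_sum;
            total_weight = 1.0f + decay * total_weight;
        }
        return 0;
    }

For example, after a first selection with p = 0.9 (S = 0.9, W = 1), the next target is 0.5·1.9 − 0.9·0.9 = 0.14, well below the base target of 0.5.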

 static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2455,30 +2384,25 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
         return;
     }

+    // clamp decay to avoid degenerate case at 1.0 (unbounded accumulation)
+    const float decay = std::min(ctx->decay, 0.99f);
+
     // fixed power law transform parameters
     const float distribution_width = 0.3f;
     const float peak_logit_value = 5.0f;
     const float tail_heaviness = 2.0f;

-    // target computation parameters
-    const float min_target = 0.0f;
-    const float max_target = 1.0f;
-    const float tail_decay = 0.50f; // exponential decay factor for history weighting
-                                    // lower = faster response, higher = more stability
-                                    // effective window ≈ 1/(1-decay) ≈ 20 tokens
-
-    // compute probabilities to get the "original" values
+    // get the original probabilities
     llama_sampler_softmax_impl(cur_p, false);

-    // store original probabilities (used for future target adaptation)
+    // store the original probabilities (needed for history update after selection)
     std::vector<float> original_probs;
     original_probs.reserve(cur_p->size);
     for (size_t i = 0; i < cur_p->size; ++i) {
         original_probs.push_back(cur_p->data[i].p);
     }

-    // calculate adaptive target
-    float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
+    float computed_target = llama_sampler_power_law_compute_target(ctx, decay);

     //
     // power law transform
@@ -2492,40 +2416,30 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok

     llama_sampler_softmax_impl(cur_p, false);

-    // sample from the transformed distribution
+    // sample from transformed distribution
     const int idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;

-    // uncomment this to log the target values and history window contents for every token
-    //
-    // fprintf(stderr, "power_law: window_size=%zu/%d values=[",
-    //     ctx->window.size(), ctx->window_size);
-    // for (size_t i = 0; i < ctx->window.size(); ++i) {
-    //     fprintf(stderr, "%.1f", ctx->window.rat(i));
-    //     if (i < ctx->window.size() - 1) fprintf(stderr, ",");
-    // }
-    // fprintf(stderr, "] computed_target=%.4f selected_token=%d orig_prob=%.4f\n",
-    //     computed_target, cur_p->data[idx].id, original_probs[idx]);
-    // fflush(stderr);
-
-    // add the ORIGINAL probability to the rolling window
-    float original_p = original_probs[idx];
-
-    ctx->window.push_back(original_p);
+    // update running history with the original probability of the selected token
+    float original_p = original_probs[idx];
+    ctx->weighted_sum = original_p + decay * ctx->weighted_sum;
+    ctx->total_weight = 1.0f + decay * ctx->total_weight;
 }

 static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_power_law *) smpl->ctx;
-    ctx->window = ring_buffer<float>(ctx->window_size);
+    ctx->weighted_sum = 0.0f;
+    ctx->total_weight = 0.0f;
 }

 static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
-    auto * result = llama_sampler_init_power_law(ctx->target, ctx->window_size, ctx->seed);
+    auto * result = llama_sampler_init_power_law(ctx->target, ctx->decay, ctx->seed);
     auto * result_ctx = (llama_sampler_power_law *) result->ctx;

     result_ctx->rng = ctx->rng;
-    result_ctx->window = ctx->window;
+    result_ctx->weighted_sum = ctx->weighted_sum;
+    result_ctx->total_weight = ctx->total_weight;

     return result;
 }
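Because the clone copies `rng`, `weighted_sum`, and `total_weight`, a cloned sampler continues exactly where the original left off. A brief sketch using the public clone API (parameter values are illustrative):

    llama_sampler * a = llama_sampler_init_power_law(0.35f, 0.9f, 1234);
    llama_sampler * b = llama_sampler_clone(a); // b now reproduces a's future selections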
@@ -2545,7 +2459,7 @@ static struct llama_sampler_i llama_sampler_power_law_i = {

 struct llama_sampler * llama_sampler_init_power_law(
         float target,
-        int32_t window_size,
+        float decay,
         uint32_t seed
 ) {
     auto seed_cur = get_rng_seed(seed);
@@ -2553,10 +2467,11 @@ struct llama_sampler * llama_sampler_init_power_law(
         /* .iface = */ &llama_sampler_power_law_i,
         /* .ctx   = */ new llama_sampler_power_law {
             /* .target       = */ target,
-            /* .window_size  = */ window_size,
+            /* .decay        = */ decay,
             /* .seed         = */ seed_cur,
             /* .rng          = */ std::mt19937(seed_cur),
-            /* .window       = */ ring_buffer<float>(window_size),
+            /* .weighted_sum = */ 0.0f,
+            /* .total_weight = */ 0.0f,
         }
     );
 }
@@ -182,33 +182,33 @@ task_params server_task::params_from_json_cmpl(
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields = json_value(data, "response_fields", std::vector<std::string>());

     params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
     params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
     params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
     params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
     params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
     params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
     params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
     params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
     params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
     params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
     params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
     params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
     params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
     params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
     params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
     params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
     params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
     params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
     params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
     params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
     params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
     params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target);
-    params.sampling.power_law_window_size = json_value(data, "power_law_window_size", defaults.sampling.power_law_window_size);
+    params.sampling.power_law_decay = json_value(data, "power_law_decay", defaults.sampling.power_law_decay);
     params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
     params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
     params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
     params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);

     params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
     params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
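On the server side, the request key becomes "power_law_decay" (replacing "power_law_window_size"). A hypothetical request against the server's completion endpoint; the prompt and parameter values are illustrative:

    curl http://localhost:8080/completion -d '{
      "prompt": "Once upon a time",
      "n_predict": 128,
      "min_p": 0.02,
      "power_law_target": 0.35,
      "power_law_decay": 0.9
    }'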