sampling : make rng selection fully modular
This commit is contained in:
parent
7bb5d4b890
commit
2826de3189
|
|
@ -1584,6 +1584,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.sampling.blue_noise = true;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--rng-type"}, "{mt19937,lowbias32}",
|
||||
"RNG type for sampling (default: mt19937)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "mt19937") {
|
||||
params.sampling.rng_type = LLAMA_RNG_TYPE_MT19937;
|
||||
} else if (value == "lowbias32") {
|
||||
params.sampling.rng_type = LLAMA_RNG_TYPE_LOWBIAS32;
|
||||
} else {
|
||||
throw std::invalid_argument("invalid value");
|
||||
}
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--temp"}, "N",
|
||||
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
|
||||
|
|
|
|||
|
|
@ -210,6 +210,7 @@ struct common_params_sampling {
|
|||
bool no_perf = false; // disable performance metrics
|
||||
bool timing_per_token = false;
|
||||
bool blue_noise = false; // use blue noise RNG instead of white noise for dist sampler
|
||||
enum llama_rng_type rng_type = LLAMA_RNG_TYPE_MT19937; // RNG type for dist sampler
|
||||
|
||||
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||
|
||||
|
|
|
|||
|
|
@ -167,11 +167,14 @@ std::string common_params_sampling::print() const {
|
|||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f\n"
|
||||
"\tblue_noise = %s, rng_type = %s",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
|
||||
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay,
|
||||
blue_noise ? "true" : "false",
|
||||
rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937");
|
||||
|
||||
return std::string(result);
|
||||
}
|
||||
|
|
@ -313,11 +316,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
|
||||
} else {
|
||||
// default: sample from distribution
|
||||
if (params.blue_noise) {
|
||||
samplers.push_back(llama_sampler_init_dist_blue_noise(params.seed));
|
||||
} else {
|
||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||
}
|
||||
samplers.push_back(llama_sampler_init_dist_rng(params.seed, params.blue_noise, params.rng_type));
|
||||
}
|
||||
} else if (params.mirostat == 1) {
|
||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||
|
|
|
|||
|
|
@ -188,6 +188,11 @@ extern "C" {
|
|||
|
||||
LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
|
||||
|
||||
// Base random number generator used by the dist sampler.
enum llama_rng_type {
|
||||
LLAMA_RNG_TYPE_MT19937 = 0, // default; selected by "--rng-type mt19937"
|
||||
LLAMA_RNG_TYPE_LOWBIAS32 = 1, // selected by "--rng-type lowbias32"
|
||||
};
|
||||
|
||||
enum llama_split_mode {
|
||||
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
|
||||
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
|
||||
|
|
@ -1295,8 +1300,8 @@ extern "C" {
|
|||
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
||||
|
||||
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||
/// @details Same sampler as llama_sampler_init_dist (which forwards here with blue_noise = false and LLAMA_RNG_TYPE_MT19937),
/// but with a caller-selected random source: rng_type picks the base generator, and when blue_noise is true
/// that generator is wrapped in the blue-noise generator. seed is handled as for llama_sampler_init_dist.
LLAMA_API struct llama_sampler * llama_sampler_init_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
|
|
|
|||
|
|
@ -633,8 +633,8 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
|
|||
struct llama_dist_rng_blue : llama_dist_rng {
|
||||
blue_noise_rng bn_rng;
|
||||
|
||||
llama_dist_rng_blue(uint32_t seed)
|
||||
: bn_rng(16, std::make_unique<llama_dist_rng_lowbias32>(seed)) {}
|
||||
llama_dist_rng_blue(std::unique_ptr<llama_dist_rng> source)
|
||||
: bn_rng(16, std::move(source)) {}
|
||||
|
||||
bool requires_sorted() override { return true; }
|
||||
|
||||
|
|
@ -1591,32 +1591,34 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|||
/* .backend_set_input = */ llama_sampler_dist_backend_set_input,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
static std::unique_ptr<llama_dist_rng> make_dist_rng(uint32_t seed, enum llama_rng_type rng_type) {
|
||||
switch (rng_type) {
|
||||
case LLAMA_RNG_TYPE_LOWBIAS32: return std::make_unique<llama_dist_rng_lowbias32>(seed);
|
||||
case LLAMA_RNG_TYPE_MT19937:
|
||||
default: return std::make_unique<llama_dist_rng_mt19937>(seed);
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
auto rng = make_dist_rng(seed_cur, rng_type);
|
||||
if (blue_noise) {
|
||||
rng = std::make_unique<llama_dist_rng_blue>(std::move(rng));
|
||||
}
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &llama_sampler_dist_i,
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
{"dist"},
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .rng = */ std::make_unique<llama_dist_rng_mt19937>(seed_cur),
|
||||
/* .rng = */ std::move(rng),
|
||||
/* .inp_uniform = */ nullptr,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &llama_sampler_dist_i,
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
{"dist-blue-noise"},
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .rng = */ std::make_unique<llama_dist_rng_blue>(seed_cur),
|
||||
/* .inp_uniform = */ nullptr,
|
||||
}
|
||||
);
|
||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
return llama_sampler_init_dist_rng(seed, false, LLAMA_RNG_TYPE_MT19937);
|
||||
}
|
||||
|
||||
// top-k
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"n_discard", n_discard},
|
||||
{"ignore_eos", sampling.ignore_eos},
|
||||
{"blue_noise", sampling.blue_noise},
|
||||
{"rng_type", sampling.rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"},
|
||||
{"stream", stream},
|
||||
{"n_probs", sampling.n_probs},
|
||||
{"min_keep", sampling.min_keep},
|
||||
|
|
@ -127,6 +128,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"n_discard", n_discard},
|
||||
{"ignore_eos", sampling.ignore_eos},
|
||||
{"blue_noise", sampling.blue_noise},
|
||||
{"rng_type", sampling.rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"},
|
||||
{"stream", stream},
|
||||
{"logit_bias", format_logit_bias(sampling.logit_bias)},
|
||||
{"n_probs", sampling.n_probs},
|
||||
|
|
@ -470,6 +472,14 @@ task_params server_task::params_from_json_cmpl(
|
|||
}
|
||||
|
||||
params.sampling.blue_noise = json_value(data, "blue_noise", params_base.sampling.blue_noise);
|
||||
{
    // Start from the server-wide setting so that a request without an "rng_type"
    // field inherits it, the same way blue_noise and ignore_eos fall back to params_base.
    params.sampling.rng_type = params_base.sampling.rng_type;

    // Per-request override; only the two known names are accepted.
    const auto rng_source = json_value(data, "rng_type", std::string(""));
    if (rng_source == "lowbias32") {
        params.sampling.rng_type = LLAMA_RNG_TYPE_LOWBIAS32;
    } else if (rng_source == "mt19937") {
        params.sampling.rng_type = LLAMA_RNG_TYPE_MT19937;
    }
    // any other value (including the empty default) keeps the server-wide setting
}
|
||||
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
|
||||
if (params.sampling.ignore_eos) {
|
||||
params.sampling.logit_bias.insert(
|
||||
|
|
|
|||
Loading…
Reference in New Issue