sampling : make rng selection fully modular

This commit is contained in:
Jan Boon 2026-02-09 05:23:58 +00:00
parent 7bb5d4b890
commit 2826de3189
6 changed files with 55 additions and 25 deletions

View File

@ -1584,6 +1584,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.blue_noise = true;
}
).set_sparam());
add_opt(common_arg(
    {"--rng-type"}, "{mt19937,lowbias32}",
    "RNG type for sampling (default: mt19937)",
    // Maps the CLI string onto the llama_rng_type enum; anything else is
    // rejected so a typo cannot silently fall back to the default generator.
    [](common_params & params, const std::string & value) {
        if (value == "lowbias32") {
            params.sampling.rng_type = LLAMA_RNG_TYPE_LOWBIAS32;
        } else if (value == "mt19937") {
            params.sampling.rng_type = LLAMA_RNG_TYPE_MT19937;
        } else {
            throw std::invalid_argument("invalid value");
        }
    }
).set_sparam());
add_opt(common_arg(
{"--temp"}, "N",
string_format("temperature (default: %.2f)", (double)params.sampling.temp),

View File

@ -210,6 +210,7 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
bool blue_noise = false; // use blue noise RNG instead of white noise for dist sampler
enum llama_rng_type rng_type = LLAMA_RNG_TYPE_MT19937; // RNG type for dist sampler
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers

View File

@ -167,11 +167,14 @@ std::string common_params_sampling::print() const {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f\n"
"\tblue_noise = %s, rng_type = %s",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay,
blue_noise ? "true" : "false",
rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937");
return std::string(result);
}
@ -313,11 +316,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
} else {
// default: sample from distribution
if (params.blue_noise) {
samplers.push_back(llama_sampler_init_dist_blue_noise(params.seed));
} else {
samplers.push_back(llama_sampler_init_dist(params.seed));
}
samplers.push_back(llama_sampler_init_dist_rng(params.seed, params.blue_noise, params.rng_type));
}
} else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp));

View File

@ -188,6 +188,11 @@ extern "C" {
LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
// RNG backend used by the "dist" sampler (see llama_sampler_init_dist_rng).
enum llama_rng_type {
    LLAMA_RNG_TYPE_MT19937   = 0, // Mersenne Twister (llama_dist_rng_mt19937) — the default
    LLAMA_RNG_TYPE_LOWBIAS32 = 1, // backed by llama_dist_rng_lowbias32
};
enum llama_split_mode {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@ -1295,8 +1300,8 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop

View File

@ -633,8 +633,8 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
struct llama_dist_rng_blue : llama_dist_rng {
blue_noise_rng bn_rng;
llama_dist_rng_blue(uint32_t seed)
: bn_rng(16, std::make_unique<llama_dist_rng_lowbias32>(seed)) {}
llama_dist_rng_blue(std::unique_ptr<llama_dist_rng> source)
: bn_rng(16, std::move(source)) {}
bool requires_sorted() override { return true; }
@ -1591,32 +1591,34 @@ static struct llama_sampler_i llama_sampler_dist_i = {
/* .backend_set_input = */ llama_sampler_dist_backend_set_input,
};
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
// Construct the distribution RNG implementation selected by rng_type.
// Any unrecognized enum value falls back to mt19937, the documented default.
static std::unique_ptr<llama_dist_rng> make_dist_rng(uint32_t seed, enum llama_rng_type rng_type) {
    if (rng_type == LLAMA_RNG_TYPE_LOWBIAS32) {
        return std::make_unique<llama_dist_rng_lowbias32>(seed);
    }
    // LLAMA_RNG_TYPE_MT19937 and anything unknown
    return std::make_unique<llama_dist_rng_mt19937>(seed);
}
struct llama_sampler * llama_sampler_init_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type) {
auto seed_cur = get_rng_seed(seed);
auto rng = make_dist_rng(seed_cur, rng_type);
if (blue_noise) {
rng = std::make_unique<llama_dist_rng_blue>(std::move(rng));
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
{"dist"},
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::make_unique<llama_dist_rng_mt19937>(seed_cur),
/* .rng = */ std::move(rng),
/* .inp_uniform = */ nullptr,
}
);
}
struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
{"dist-blue-noise"},
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::make_unique<llama_dist_rng_blue>(seed_cur),
/* .inp_uniform = */ nullptr,
}
);
// Backwards-compatible entry point: white-noise mt19937 distribution sampling.
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
    return llama_sampler_init_dist_rng(seed, /* blue_noise = */ false, LLAMA_RNG_TYPE_MT19937);
}
// top-k

View File

@ -67,6 +67,7 @@ json task_params::to_json(bool only_metrics) const {
{"n_discard", n_discard},
{"ignore_eos", sampling.ignore_eos},
{"blue_noise", sampling.blue_noise},
{"rng_type", sampling.rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"},
{"stream", stream},
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
@ -127,6 +128,7 @@ json task_params::to_json(bool only_metrics) const {
{"n_discard", n_discard},
{"ignore_eos", sampling.ignore_eos},
{"blue_noise", sampling.blue_noise},
{"rng_type", sampling.rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"},
{"stream", stream},
{"logit_bias", format_logit_bias(sampling.logit_bias)},
{"n_probs", sampling.n_probs},
@ -470,6 +472,14 @@ task_params server_task::params_from_json_cmpl(
}
params.sampling.blue_noise = json_value(data, "blue_noise", params_base.sampling.blue_noise);
{
    // Per-request RNG type override; when the key is absent (or not one of the
    // recognized strings) params.sampling.rng_type is left untouched.
    // NOTE(review): unlike the --rng-type CLI arg, an invalid value here is
    // silently ignored rather than rejected — confirm this is intentional.
    const auto rng_source = json_value(data, "rng_type", std::string(""));
    if (rng_source == "lowbias32") {
        params.sampling.rng_type = LLAMA_RNG_TYPE_LOWBIAS32;
    } else if (rng_source == "mt19937") {
        params.sampling.rng_type = LLAMA_RNG_TYPE_MT19937;
    }
}
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
if (params.sampling.ignore_eos) {
params.sampling.logit_bias.insert(