diff --git a/common/arg.cpp b/common/arg.cpp index 924b5198a2..7181e31cd7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1584,6 +1584,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.blue_noise = true; } ).set_sparam()); + add_opt(common_arg( + {"--rng-type"}, "{mt19937,lowbias32}", + "RNG type for sampling (default: mt19937)", + [](common_params & params, const std::string & value) { + if (value == "mt19937") { + params.sampling.rng_type = LLAMA_RNG_TYPE_MT19937; + } else if (value == "lowbias32") { + params.sampling.rng_type = LLAMA_RNG_TYPE_LOWBIAS32; + } else { + throw std::invalid_argument("invalid value"); + } + } + ).set_sparam()); add_opt(common_arg( {"--temp"}, "N", string_format("temperature (default: %.2f)", (double)params.sampling.temp), diff --git a/common/common.h b/common/common.h index 0a76a1e26c..662eeb51e2 100644 --- a/common/common.h +++ b/common/common.h @@ -210,6 +210,7 @@ struct common_params_sampling { bool no_perf = false; // disable performance metrics bool timing_per_token = false; bool blue_noise = false; // use blue noise RNG instead of white noise for dist sampler + enum llama_rng_type rng_type = LLAMA_RNG_TYPE_MT19937; // RNG type for dist sampler uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers diff --git a/common/sampling.cpp b/common/sampling.cpp index 2811eb3a48..f98bd7b311 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -167,11 +167,14 @@ std::string common_params_sampling::print() const { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f", + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f\n" + "\tblue_noise = %s, rng_type = %s", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, - mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay); + mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay, + blue_noise ? "true" : "false", + rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"); return std::string(result); } @@ -313,11 +316,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed)); } else { // default: sample from distribution - if (params.blue_noise) { - samplers.push_back(llama_sampler_init_dist_blue_noise(params.seed)); - } else { - samplers.push_back(llama_sampler_init_dist(params.seed)); - } + samplers.push_back(llama_sampler_init_dist_rng(params.seed, params.blue_noise, params.rng_type)); } } else if (params.mirostat == 1) { samplers.push_back(llama_sampler_init_temp(params.temp)); diff --git a/include/llama.h b/include/llama.h index 22f08e1683..d9f4acc5c7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -188,6 +188,11 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); + enum llama_rng_type { + LLAMA_RNG_TYPE_MT19937 = 0, + LLAMA_RNG_TYPE_LOWBIAS32 = 1, + }; + enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -1295,8 +1300,8 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); /// seed == LLAMA_DEFAULT_SEED to use a random seed. - LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); - LLAMA_API struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Setting k <= 0 makes this a noop diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index 34f8a62ab4..0a74f2d26f 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -633,8 +633,8 @@ struct llama_dist_rng_mt19937 : llama_dist_rng { struct llama_dist_rng_blue : llama_dist_rng { blue_noise_rng bn_rng; - llama_dist_rng_blue(uint32_t seed) - : bn_rng(16, std::make_unique(seed)) {} + llama_dist_rng_blue(std::unique_ptr source) + : bn_rng(16, std::move(source)) {} bool requires_sorted() override { return true; } @@ -1591,32 +1591,34 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .backend_set_input = */ llama_sampler_dist_backend_set_input, }; -struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { +static std::unique_ptr make_dist_rng(uint32_t seed, enum llama_rng_type rng_type) { + switch (rng_type) { + case LLAMA_RNG_TYPE_LOWBIAS32: return std::make_unique(seed); + case LLAMA_RNG_TYPE_MT19937: + default: return std::make_unique(seed); + } +} + +struct llama_sampler * llama_sampler_init_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type) { auto seed_cur = get_rng_seed(seed); + auto rng = make_dist_rng(seed_cur, rng_type); + if (blue_noise) { + rng = std::make_unique(std::move(rng)); + } return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { {"dist"}, /* .seed = */ seed, /* .seed_cur = */ seed_cur, - /* .rng = */ std::make_unique(seed_cur), + /* .rng = */ std::move(rng), /* .inp_uniform = */ nullptr, } ); } -struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) { - auto seed_cur = get_rng_seed(seed); - return llama_sampler_init( - /* .iface = */ &llama_sampler_dist_i, - /* .ctx = */ new llama_sampler_dist { - {"dist-blue-noise"}, - /* .seed = */ seed, - /* .seed_cur = */ seed_cur, - /* .rng = */ std::make_unique(seed_cur), - /* .inp_uniform = */ nullptr, - } - ); +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + return llama_sampler_init_dist_rng(seed, false, LLAMA_RNG_TYPE_MT19937); } // top-k diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 16c3cf12d0..d717165daa 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -67,6 +67,7 @@ json task_params::to_json(bool only_metrics) const { {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, {"blue_noise", sampling.blue_noise}, + {"rng_type", sampling.rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"}, {"stream", stream}, {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, @@ -127,6 +128,7 @@ json task_params::to_json(bool only_metrics) const { {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, {"blue_noise", sampling.blue_noise}, + {"rng_type", sampling.rng_type == LLAMA_RNG_TYPE_LOWBIAS32 ? "lowbias32" : "mt19937"}, {"stream", stream}, {"logit_bias", format_logit_bias(sampling.logit_bias)}, {"n_probs", sampling.n_probs}, @@ -470,6 +472,14 @@ task_params server_task::params_from_json_cmpl( } params.sampling.blue_noise = json_value(data, "blue_noise", params_base.sampling.blue_noise); + { + const auto rng_source = json_value(data, "rng_type", std::string("")); + if (rng_source == "lowbias32") { + params.sampling.rng_type = LLAMA_RNG_TYPE_LOWBIAS32; + } else if (rng_source == "mt19937") { + params.sampling.rng_type = LLAMA_RNG_TYPE_MT19937; + } + } params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); if (params.sampling.ignore_eos) { params.sampling.logit_bias.insert(