From 7f433763b6683b7f6e5a323729acf73a1a4e8ec9 Mon Sep 17 00:00:00 2001
From: Jan Boon
Date: Mon, 9 Feb 2026 02:47:42 +0000
Subject: [PATCH] sampling : implement disabled branch to support blue noise

Add llama_sample_dist_rng, which forwards to llama_sample_dist for RNGs
that do not require sorted input and otherwise sorts the candidates and
walks the CDF using a single rng.nextf() draw. Use it in the disabled
branch of llama_sampler_dist_apply, and fix the comment typos in
llama_dist_rng.

---
 src/llama-sampler.cpp | 52 ++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 46 insertions(+), 6 deletions(-)

diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp
index 275e4c5b56..02258d981b 100644
--- a/src/llama-sampler.cpp
+++ b/src/llama-sampler.cpp
@@ -448,10 +448,12 @@ struct blue_noise_rng {
 struct llama_dist_rng {
     virtual ~llama_dist_rng() = default;
 
-    virtual bool requires_sorted() = 0; // whether the RNG requires sorted input for proper properties
+    // whether the RNG requires sorted input for proper properties
+    // this also indicates whether the RNG output itself must be consumed in a coherent order
+    virtual bool requires_sorted() = 0;
 
-    // for compatilibility with std::discrete_distribution
-    // nly used in a disabled branch of llama_sampler_dist_apply
+    // for compatibility with std::discrete_distribution
+    // only used in a disabled branch of llama_sampler_dist_apply
     virtual uint32_t rng_min() = 0;
     virtual uint32_t rng_max() = 0;
     virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
@@ -474,6 +476,46 @@ struct llama_dist_urbg {
     result_type operator()() { return rng.next(); }
 };
 
+// wrapper to use existing llama_sample_dist for mt19937, otherwise implements CDF walk directly
+// this is currently only used in a disabled branch of llama_sampler_dist_apply, added for compatibility and potential use by other samplers
+// flag normalized to skip recomputing the probability sum when probs already sum to 1
+// returns the index of the selected entry in cur_p->data; the CDF walk returns -1 when cur_p is empty
+static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng & rng, bool normalized = false) {
+    if (!rng.requires_sorted()) {
+        llama_dist_urbg urbg{rng};
+        return llama_sample_dist(cur_p, urbg);
+    }
+
+    // this RNG requires sorted input, so sort here if the caller has not already done so
+    if (!cur_p->sorted) {
+        llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
+    }
+
+    const double rnd = rng.nextf();
+
+    // target mass on the CDF; without the normalized guarantee, scale rnd by the actual total
+    double sum_tgt = rnd;
+    if (!normalized) {
+        double sum_cum = 0.0;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            sum_cum += cur_p->data[i].p;
+        }
+        sum_tgt = sum_cum * rnd;
+    }
+
+    double sum_run = 0.0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        sum_run += cur_p->data[i].p;
+        if (sum_run >= sum_tgt) {
+            return (int) i;
+        }
+    }
+
+    // fall back to the last entry if the running sum never reached the target
+    // cast before subtracting so an empty array yields -1 instead of wrapping around
+    return (int) cur_p->size - 1;
+}
+
 struct llama_dist_rng_mt19937 : llama_dist_rng {
     std::mt19937 rng;
 
@@ -1301,9 +1343,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
         cur_p->data[i].p /= sum_cum;
     }
 
-    // this implementation is not guaranteed to preserve blue noise properties
-    llama_dist_urbg urbg{*ctx->rng};
-    cur_p->selected = llama_sample_dist(cur_p, urbg);
+    cur_p->selected = llama_sample_dist_rng(cur_p, *ctx->rng, true);
 #endif
 }
 