sampling : implement disabled branch to support blue noise

This commit is contained in:
Jan Boon 2026-02-09 02:47:42 +00:00
parent 23b5a5c026
commit 7f433763b6
1 changed files with 48 additions and 6 deletions

View File

@ -448,10 +448,12 @@ struct blue_noise_rng {
struct llama_dist_rng {
virtual ~llama_dist_rng() = default;
virtual bool requires_sorted() = 0; // whether the RNG requires sorted input for proper properties
// whether the RNG requires sorted input for proper properties
// this also indicates whether the RNG output itself must be consumed in a coherent order
virtual bool requires_sorted() = 0;
// for compatilibility with std::discrete_distribution
// nly used in a disabled branch of llama_sampler_dist_apply
// for compatibility with std::discrete_distribution
// only used in a disabled branch of llama_sampler_dist_apply
virtual uint32_t rng_min() = 0;
virtual uint32_t rng_max() = 0;
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
@ -474,6 +476,48 @@ struct llama_dist_urbg {
result_type operator()() { return rng.next(); }
};
// wrapper to use existing llama_sample_dist for mt19937, otherwise implements CDF walk directly
// this is currently only used in a disabled branch of llama_sampler_dist_apply, added for compatibility and potential use by other samplers
// flag normalized to skip recomputing the probability sum when probs already sum to 1
static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng & rng, bool normalized = false) {
if (!rng.requires_sorted()) {
llama_dist_urbg urbg{rng};
return llama_sample_dist(cur_p, urbg);
}
if (!cur_p->sorted) {
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
}
const double rnd = rng.nextf();
double sum_run = 0.0;
if (normalized) {
for (size_t i = 0; i < cur_p->size; ++i) {
sum_run += cur_p->data[i].p;
if (sum_run >= rnd) {
return i;
}
}
} else {
double sum_cum = 0.0;
for (size_t i = 0; i < cur_p->size; ++i) {
sum_cum += cur_p->data[i].p;
}
const double sum_tgt = sum_cum * rnd;
for (size_t i = 0; i < cur_p->size; ++i) {
sum_run += cur_p->data[i].p;
if (sum_run >= sum_tgt) {
return i;
}
}
}
return (int)(cur_p->size - 1);
}
struct llama_dist_rng_mt19937 : llama_dist_rng {
std::mt19937 rng;
@ -1301,9 +1345,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
cur_p->data[i].p /= sum_cum;
}
// this implementation is not guaranteed to preserve blue noise properties
llama_dist_urbg urbg{*ctx->rng};
cur_p->selected = llama_sample_dist(cur_p, urbg);
cur_p->selected = llama_sample_dist_rng(cur_p, *ctx->rng, true);
#endif
}