sampling : blue noise requires tokens to be sorted

This commit is contained in:
Jan Boon 2026-02-05 10:46:18 +00:00
parent d5def78bb0
commit e829f2904e
1 changed files with 11 additions and 1 deletions

View File

@ -443,9 +443,10 @@ struct blue_noise_rng {
struct llama_dist_rng {
virtual ~llama_dist_rng() = default;
virtual bool requires_sorted() = 0; // whether the RNG requires sorted input for proper properties
virtual uint32_t rng_min() = 0;
virtual uint32_t rng_max() = 0;
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
virtual double nextf() = 0; // uniform double in [0, 1)
virtual void reseed(uint32_t s) = 0;
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
@ -468,6 +469,8 @@ struct llama_dist_rng_white : llama_dist_rng {
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
bool requires_sorted() override { return false; }
uint32_t rng_min() override { return std::mt19937::min(); }
uint32_t rng_max() override { return std::mt19937::max(); }
@ -496,6 +499,8 @@ struct llama_dist_rng_blue : llama_dist_rng {
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
bool requires_sorted() override { return true; }
uint32_t rng_min() override { return 0; }
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
@ -1234,6 +1239,11 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
return;
}
// sort if required by the RNG (e.g., blue noise needs sorted input for proper temporal properties)
if (ctx->rng->requires_sorted() && !cur_p->sorted) {
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
}
// max logit for numerical stability
float max_l = cur_p->data[0].logit;
if (!cur_p->sorted) {