sampling : make white noise source for blue noise modular as well

This commit is contained in:
Jan Boon 2026-02-09 04:03:10 +00:00
parent ae31b151e9
commit 2c7269fd8d
1 changed files with 136 additions and 57 deletions

View File

@ -334,6 +334,27 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
cur_p->size = k;
}
// abstract RNG interface for the dist sampler
struct llama_dist_rng {
virtual ~llama_dist_rng() = default;
// whether the RNG requires sorted input for proper properties
// this also indicates whether the RNG output itself must be consumed in a coherent order
virtual bool requires_sorted() = 0;
// for compatibility with std::discrete_distribution
// only used in a disabled branch of llama_sampler_dist_apply
virtual uint32_t rng_min() = 0;
virtual uint32_t rng_max() = 0;
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
virtual uint32_t next32() = 0; // uniform 32 bits
virtual uint64_t next64() = 0; // uniform 64 bits
virtual double nextf() = 0; // uniform double in [0, 1)
virtual void reseed(uint32_t s) = 0;
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
};
// generative error diffusion for sequential blue noise
// pseudo-random number generator with ~6db/octave blue noise
// this generator produces a uniform distribution
@ -343,32 +364,38 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
// nor when integer outputs are used with the modulo operator
struct blue_noise_rng {
uint8_t bit_depth = 0;
uint32_t seed = 0;
uint32_t position = 0;
std::unique_ptr<llama_dist_rng> rng;
// binary tree of 1-bit 50% duty cycle error diffusion dithering blue noise generators
std::vector<std::array<int8_t, 2>> states; // {err0, err1} per tree node
blue_noise_rng() = default;
blue_noise_rng(uint8_t bit_depth, uint32_t seed) {
init(bit_depth, seed);
blue_noise_rng(uint8_t bit_depth, std::unique_ptr<llama_dist_rng> rng) {
init(bit_depth, std::move(rng));
}
// currently this uses lowbias32 as the white noise RNG source
// in practice, any white noise RNG source works
// this random noise is used to perturb the error diffusion weights (binary decision)
// as well as to fill in the low bits of the double precision output to eliminate aliasing
static uint32_t hash(uint32_t x) { // lowbias32
x ^= x >> 16; x *= 0x21f0aaad;
x ^= x >> 15; x *= 0x735a2d97;
x ^= x >> 15;
return x;
// custom copy (clone the underlying RNG)
blue_noise_rng(const blue_noise_rng & other)
: bit_depth(other.bit_depth)
, rng(other.rng ? other.rng->clone() : nullptr)
, states(other.states) {}
blue_noise_rng & operator=(const blue_noise_rng & other) {
if (this != &other) {
bit_depth = other.bit_depth;
rng = other.rng ? other.rng->clone() : nullptr;
states = other.states;
}
return *this;
}
void init(uint8_t depth, uint32_t s) {
blue_noise_rng(blue_noise_rng &&) = default;
blue_noise_rng & operator=(blue_noise_rng &&) = default;
void init(uint8_t depth, std::unique_ptr<llama_dist_rng> source) {
bit_depth = std::clamp<uint8_t>(depth, 1, 16);
seed = hash(s);
rng = std::move(source);
const int n = (1 << bit_depth) - 1;
states.resize(n); // at 16-bit depth, this uses 128KB of state
@ -376,9 +403,13 @@ struct blue_noise_rng {
reset();
}
void reseed(uint32_t s) {
rng->reseed(s);
reset();
}
void reset() {
const int n = (int)states.size();
position = 0;
// 5 reachable states with distribution 3:3:2:1:1
// established based on empirical testing
@ -390,15 +421,12 @@ struct blue_noise_rng {
{-1, -1},
};
for (int i = 0; i < n; i++) {
uint32_t h = hash((uint32_t)i ^ seed) % 10;
uint32_t h = rng->next32() % 10;
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
}
}
uint16_t next(uint32_t * hash_remainder = nullptr) {
uint32_t h = hash(position ^ seed);
position++;
uint16_t advance(uint32_t h) {
// traverse binary tree, one error diffusion ditherer per population split
// thresholding output at any value still produces blue noise
uint32_t acc = 0;
@ -416,50 +444,39 @@ struct blue_noise_rng {
acc = acc * 2 + out;
}
if (hash_remainder) {
*hash_remainder = h >> bit_depth; // unused bits from random hash
}
return (uint16_t)acc;
}
// blue noise in the upper bit_depth bits, white noise hash remainder in the lower bits
uint16_t next() {
uint32_t h = rng->next32();
return advance(h);
}
// blue noise in the upper bit_depth bits, white noise in the lower bits
// do not use with modulo operator, as it would just produce white noise
uint32_t next32() {
uint32_t rem;
uint32_t val = next(&rem);
return (val << (32 - bit_depth)) | rem;
uint32_t h = rng->next32();
uint32_t val = advance(h);
return (val << (32 - bit_depth)) | (h >> bit_depth);
}
// blue noise in the upper bits, white noise in the lower bits
uint64_t next64() {
uint64_t r = rng->next64();
uint32_t lo = (uint32_t)r;
uint32_t h = (uint32_t)(r >> 32);
uint32_t val = advance(h);
uint32_t hi = (val << (32 - bit_depth)) | (h >> bit_depth);
return ((uint64_t)hi << 32) | lo;
}
// uniform double in [0, 1) with blue noise temporal autocorrelation
double nextf() {
uint32_t lo = hash(position ^ ~seed); // white noise low bits
uint32_t hi = next32(); // blue noise high bits
uint64_t combined = ((uint64_t)hi << 32) | lo;
uint64_t combined = next64();
return (combined >> 11) * 0x1.0p-53;
}
};
// abstract RNG interface for the dist sampler
struct llama_dist_rng {
virtual ~llama_dist_rng() = default;
// whether the RNG requires sorted input for proper properties
// this also indicates whether the RNG output itself must be consumed in a coherent order
virtual bool requires_sorted() = 0;
// for compatibility with std::discrete_distribution
// only used in a disabled branch of llama_sampler_dist_apply
virtual uint32_t rng_min() = 0;
virtual uint32_t rng_max() = 0;
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
virtual double nextf() = 0; // uniform double in [0, 1)
virtual void reseed(uint32_t s) = 0;
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
};
// adapter to satisfy UniformRandomBitGenerator for std::discrete_distribution
// note: not guaranteed to preserve blue noise properties
// this is only used in a disabled branch of llama_sampler_dist_apply, added for compatibility
@ -515,6 +532,55 @@ static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng
return (int)(cur_p->size - 1);
}
struct llama_dist_rng_lowbias32 : llama_dist_rng {
uint32_t hashed_seed = 0;
uint32_t position = 0;
llama_dist_rng_lowbias32(uint32_t seed) : hashed_seed(hash(seed)), position(0) {}
bool requires_sorted() override { return false; }
uint32_t rng_min() override { return 0; }
uint32_t rng_max() override { return UINT32_MAX; }
static uint32_t hash(uint32_t x) { // lowbias32
// coefficients from https://github.com/skeeto/hash-prospector/issues/19
x ^= x >> 16; x *= 0x21f0aaad;
x ^= x >> 15; x *= 0x735a2d97;
x ^= x >> 15;
return x;
}
uint32_t next() override {
uint32_t val = hash(position ^ hashed_seed);
position++;
return val;
}
uint32_t next32() override {
return next();
}
uint64_t next64() override {
uint64_t lo = hash(position ^ ~hashed_seed); // secondary sequence using opposing seed
uint64_t hi = next();
return (hi << 32) | lo;
}
double nextf() override {
uint64_t combined = next64();
return (combined >> 11) * 0x1.0p-53;
}
void reseed(uint32_t s) override {
hashed_seed = hash(s);
position = 0;
}
std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_lowbias32>(*this);
}
};
struct llama_dist_rng_mt19937 : llama_dist_rng {
std::mt19937 rng;
@ -524,11 +590,18 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
uint32_t rng_min() override { return std::mt19937::min(); }
uint32_t rng_max() override { return std::mt19937::max(); }
uint32_t next() override { return rng(); }
uint32_t next() override {
uint32_t next32() override {
return rng();
}
uint64_t next64() override {
uint64_t hi = (uint64_t)rng() << 32;
uint64_t lo = (uint64_t)rng();
return hi | lo;
}
double nextf() override {
std::uniform_real_distribution<double> dist(0.0, 1.0);
return dist(rng);
@ -546,15 +619,21 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
struct llama_dist_rng_blue : llama_dist_rng {
blue_noise_rng bn_rng;
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
llama_dist_rng_blue(uint32_t seed)
: bn_rng(16, std::make_unique<llama_dist_rng_lowbias32>(seed)) {}
bool requires_sorted() override { return true; }
uint32_t rng_min() override { return 0; }
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
uint32_t next() override { return bn_rng.next(); }
uint32_t next() override {
return bn_rng.next();
uint32_t next32() override {
return bn_rng.next32();
}
uint64_t next64() override {
return bn_rng.next64();
}
double nextf() override {
@ -562,7 +641,7 @@ struct llama_dist_rng_blue : llama_dist_rng {
}
void reseed(uint32_t s) override {
bn_rng.init(16, s);
bn_rng.reseed(s);
}
std::unique_ptr<llama_dist_rng> clone() const override {