sampling : make white noise source for blue noise modular as well
This commit is contained in:
parent
ae31b151e9
commit
2c7269fd8d
|
|
@ -334,6 +334,27 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
cur_p->size = k;
|
||||
}
|
||||
|
||||
// abstract RNG interface for the dist sampler
|
||||
struct llama_dist_rng {
|
||||
virtual ~llama_dist_rng() = default;
|
||||
|
||||
// whether the RNG requires sorted input for proper properties
|
||||
// this also indicates whether the RNG output itself must be consumed in a coherent order
|
||||
virtual bool requires_sorted() = 0;
|
||||
|
||||
// for compatibility with std::discrete_distribution
|
||||
// only used in a disabled branch of llama_sampler_dist_apply
|
||||
virtual uint32_t rng_min() = 0;
|
||||
virtual uint32_t rng_max() = 0;
|
||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||
|
||||
virtual uint32_t next32() = 0; // uniform 32 bits
|
||||
virtual uint64_t next64() = 0; // uniform 64 bits
|
||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||
virtual void reseed(uint32_t s) = 0;
|
||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||
};
|
||||
|
||||
// generative error diffusion for sequential blue noise
|
||||
// pseudo-random number generator with ~6db/octave blue noise
|
||||
// this generator produces a uniform distribution
|
||||
|
|
@ -343,32 +364,38 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
// nor when integer outputs are used with the modulo operator
|
||||
struct blue_noise_rng {
|
||||
uint8_t bit_depth = 0;
|
||||
uint32_t seed = 0;
|
||||
uint32_t position = 0;
|
||||
std::unique_ptr<llama_dist_rng> rng;
|
||||
|
||||
// binary tree of 1-bit 50% duty cycle error diffusion dithering blue noise generators
|
||||
std::vector<std::array<int8_t, 2>> states; // {err0, err1} per tree node
|
||||
|
||||
blue_noise_rng() = default;
|
||||
|
||||
blue_noise_rng(uint8_t bit_depth, uint32_t seed) {
|
||||
init(bit_depth, seed);
|
||||
blue_noise_rng(uint8_t bit_depth, std::unique_ptr<llama_dist_rng> rng) {
|
||||
init(bit_depth, std::move(rng));
|
||||
}
|
||||
|
||||
// currently this uses lowbias32 as the white noise RNG source
|
||||
// in practice, any white noise RNG source works
|
||||
// this random noise is used to perturb the error diffusion weights (binary decision)
|
||||
// as well as to fill in the low bits of the double precision output to eliminate aliasing
|
||||
static uint32_t hash(uint32_t x) { // lowbias32
|
||||
x ^= x >> 16; x *= 0x21f0aaad;
|
||||
x ^= x >> 15; x *= 0x735a2d97;
|
||||
x ^= x >> 15;
|
||||
return x;
|
||||
// custom copy (clone the underlying RNG)
|
||||
blue_noise_rng(const blue_noise_rng & other)
|
||||
: bit_depth(other.bit_depth)
|
||||
, rng(other.rng ? other.rng->clone() : nullptr)
|
||||
, states(other.states) {}
|
||||
|
||||
blue_noise_rng & operator=(const blue_noise_rng & other) {
|
||||
if (this != &other) {
|
||||
bit_depth = other.bit_depth;
|
||||
rng = other.rng ? other.rng->clone() : nullptr;
|
||||
states = other.states;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void init(uint8_t depth, uint32_t s) {
|
||||
blue_noise_rng(blue_noise_rng &&) = default;
|
||||
blue_noise_rng & operator=(blue_noise_rng &&) = default;
|
||||
|
||||
void init(uint8_t depth, std::unique_ptr<llama_dist_rng> source) {
|
||||
bit_depth = std::clamp<uint8_t>(depth, 1, 16);
|
||||
seed = hash(s);
|
||||
rng = std::move(source);
|
||||
|
||||
const int n = (1 << bit_depth) - 1;
|
||||
states.resize(n); // at 16-bit depth, this uses 128KB of state
|
||||
|
|
@ -376,9 +403,13 @@ struct blue_noise_rng {
|
|||
reset();
|
||||
}
|
||||
|
||||
void reseed(uint32_t s) {
|
||||
rng->reseed(s);
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
const int n = (int)states.size();
|
||||
position = 0;
|
||||
|
||||
// 5 reachable states with distribution 3:3:2:1:1
|
||||
// established based on empirical testing
|
||||
|
|
@ -390,15 +421,12 @@ struct blue_noise_rng {
|
|||
{-1, -1},
|
||||
};
|
||||
for (int i = 0; i < n; i++) {
|
||||
uint32_t h = hash((uint32_t)i ^ seed) % 10;
|
||||
uint32_t h = rng->next32() % 10;
|
||||
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
|
||||
}
|
||||
}
|
||||
|
||||
uint16_t next(uint32_t * hash_remainder = nullptr) {
|
||||
uint32_t h = hash(position ^ seed);
|
||||
position++;
|
||||
|
||||
uint16_t advance(uint32_t h) {
|
||||
// traverse binary tree, one error diffusion ditherer per population split
|
||||
// thresholding output at any value still produces blue noise
|
||||
uint32_t acc = 0;
|
||||
|
|
@ -416,50 +444,39 @@ struct blue_noise_rng {
|
|||
|
||||
acc = acc * 2 + out;
|
||||
}
|
||||
|
||||
if (hash_remainder) {
|
||||
*hash_remainder = h >> bit_depth; // unused bits from random hash
|
||||
}
|
||||
|
||||
return (uint16_t)acc;
|
||||
}
|
||||
|
||||
// blue noise in the upper bit_depth bits, white noise hash remainder in the lower bits
|
||||
uint16_t next() {
|
||||
uint32_t h = rng->next32();
|
||||
return advance(h);
|
||||
}
|
||||
|
||||
// blue noise in the upper bit_depth bits, white noise in the lower bits
|
||||
// do not use with modulo operator, as it would just produce white noise
|
||||
uint32_t next32() {
|
||||
uint32_t rem;
|
||||
uint32_t val = next(&rem);
|
||||
return (val << (32 - bit_depth)) | rem;
|
||||
uint32_t h = rng->next32();
|
||||
uint32_t val = advance(h);
|
||||
return (val << (32 - bit_depth)) | (h >> bit_depth);
|
||||
}
|
||||
|
||||
// blue noise in the upper bits, white noise in the lower bits
|
||||
uint64_t next64() {
|
||||
uint64_t r = rng->next64();
|
||||
uint32_t lo = (uint32_t)r;
|
||||
uint32_t h = (uint32_t)(r >> 32);
|
||||
uint32_t val = advance(h);
|
||||
uint32_t hi = (val << (32 - bit_depth)) | (h >> bit_depth);
|
||||
return ((uint64_t)hi << 32) | lo;
|
||||
}
|
||||
|
||||
// uniform double in [0, 1) with blue noise temporal autocorrelation
|
||||
double nextf() {
|
||||
uint32_t lo = hash(position ^ ~seed); // white noise low bits
|
||||
uint32_t hi = next32(); // blue noise high bits
|
||||
uint64_t combined = ((uint64_t)hi << 32) | lo;
|
||||
uint64_t combined = next64();
|
||||
return (combined >> 11) * 0x1.0p-53;
|
||||
}
|
||||
};
|
||||
|
||||
// abstract RNG interface for the dist sampler
|
||||
struct llama_dist_rng {
|
||||
virtual ~llama_dist_rng() = default;
|
||||
|
||||
// whether the RNG requires sorted input for proper properties
|
||||
// this also indicates whether the RNG output itself must be consumed in a coherent order
|
||||
virtual bool requires_sorted() = 0;
|
||||
|
||||
// for compatibility with std::discrete_distribution
|
||||
// only used in a disabled branch of llama_sampler_dist_apply
|
||||
virtual uint32_t rng_min() = 0;
|
||||
virtual uint32_t rng_max() = 0;
|
||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||
|
||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||
virtual void reseed(uint32_t s) = 0;
|
||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||
};
|
||||
|
||||
// adapter to satisfy UniformRandomBitGenerator for std::discrete_distribution
|
||||
// note: not guaranteed to preserve blue noise properties
|
||||
// this is only used in a disabled branch of llama_sampler_dist_apply, added for compatibility
|
||||
|
|
@ -515,6 +532,55 @@ static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng
|
|||
return (int)(cur_p->size - 1);
|
||||
}
|
||||
|
||||
struct llama_dist_rng_lowbias32 : llama_dist_rng {
|
||||
uint32_t hashed_seed = 0;
|
||||
uint32_t position = 0;
|
||||
|
||||
llama_dist_rng_lowbias32(uint32_t seed) : hashed_seed(hash(seed)), position(0) {}
|
||||
|
||||
bool requires_sorted() override { return false; }
|
||||
uint32_t rng_min() override { return 0; }
|
||||
uint32_t rng_max() override { return UINT32_MAX; }
|
||||
|
||||
static uint32_t hash(uint32_t x) { // lowbias32
|
||||
// coefficients from https://github.com/skeeto/hash-prospector/issues/19
|
||||
x ^= x >> 16; x *= 0x21f0aaad;
|
||||
x ^= x >> 15; x *= 0x735a2d97;
|
||||
x ^= x >> 15;
|
||||
return x;
|
||||
}
|
||||
|
||||
uint32_t next() override {
|
||||
uint32_t val = hash(position ^ hashed_seed);
|
||||
position++;
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t next32() override {
|
||||
return next();
|
||||
}
|
||||
|
||||
uint64_t next64() override {
|
||||
uint64_t lo = hash(position ^ ~hashed_seed); // secondary sequence using opposing seed
|
||||
uint64_t hi = next();
|
||||
return (hi << 32) | lo;
|
||||
}
|
||||
|
||||
double nextf() override {
|
||||
uint64_t combined = next64();
|
||||
return (combined >> 11) * 0x1.0p-53;
|
||||
}
|
||||
|
||||
void reseed(uint32_t s) override {
|
||||
hashed_seed = hash(s);
|
||||
position = 0;
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
return std::make_unique<llama_dist_rng_lowbias32>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
||||
std::mt19937 rng;
|
||||
|
||||
|
|
@ -524,11 +590,18 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
|
|||
|
||||
uint32_t rng_min() override { return std::mt19937::min(); }
|
||||
uint32_t rng_max() override { return std::mt19937::max(); }
|
||||
uint32_t next() override { return rng(); }
|
||||
|
||||
uint32_t next() override {
|
||||
uint32_t next32() override {
|
||||
return rng();
|
||||
}
|
||||
|
||||
uint64_t next64() override {
|
||||
uint64_t hi = (uint64_t)rng() << 32;
|
||||
uint64_t lo = (uint64_t)rng();
|
||||
return hi | lo;
|
||||
}
|
||||
|
||||
double nextf() override {
|
||||
std::uniform_real_distribution<double> dist(0.0, 1.0);
|
||||
return dist(rng);
|
||||
|
|
@ -546,15 +619,21 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
|
|||
struct llama_dist_rng_blue : llama_dist_rng {
|
||||
blue_noise_rng bn_rng;
|
||||
|
||||
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
||||
llama_dist_rng_blue(uint32_t seed)
|
||||
: bn_rng(16, std::make_unique<llama_dist_rng_lowbias32>(seed)) {}
|
||||
|
||||
bool requires_sorted() override { return true; }
|
||||
|
||||
uint32_t rng_min() override { return 0; }
|
||||
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
|
||||
uint32_t next() override { return bn_rng.next(); }
|
||||
|
||||
uint32_t next() override {
|
||||
return bn_rng.next();
|
||||
uint32_t next32() override {
|
||||
return bn_rng.next32();
|
||||
}
|
||||
|
||||
uint64_t next64() override {
|
||||
return bn_rng.next64();
|
||||
}
|
||||
|
||||
double nextf() override {
|
||||
|
|
@ -562,7 +641,7 @@ struct llama_dist_rng_blue : llama_dist_rng {
|
|||
}
|
||||
|
||||
void reseed(uint32_t s) override {
|
||||
bn_rng.init(16, s);
|
||||
bn_rng.reseed(s);
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
|
|
|
|||
Loading…
Reference in New Issue