diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index ad0a20f0ff..2ddb2978eb 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -334,6 +334,27 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } +// abstract RNG interface for the dist sampler +struct llama_dist_rng { + virtual ~llama_dist_rng() = default; + + // whether the RNG requires sorted input for proper properties + // this also indicates whether the RNG output itself must be consumed in a coherent order + virtual bool requires_sorted() = 0; + + // for compatibility with std::discrete_distribution + // only used in a disabled branch of llama_sampler_dist_apply + virtual uint32_t rng_min() = 0; + virtual uint32_t rng_max() = 0; + virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()] + + virtual uint32_t next32() = 0; // uniform 32 bits + virtual uint64_t next64() = 0; // uniform 64 bits + virtual double nextf() = 0; // uniform double in [0, 1) + virtual void reseed(uint32_t s) = 0; + virtual std::unique_ptr clone() const = 0; +}; + // generative error diffusion for sequential blue noise // pseudo-random number generator with ~6db/octave blue noise // this generator produces a uniform distribution @@ -343,32 +364,38 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) // nor when integer outputs are used with the modulo operator struct blue_noise_rng { uint8_t bit_depth = 0; - uint32_t seed = 0; - uint32_t position = 0; + std::unique_ptr rng; // binary tree of 1-bit 50% duty cycle error diffusion dithering blue noise generators std::vector> states; // {err0, err1} per tree node blue_noise_rng() = default; - blue_noise_rng(uint8_t bit_depth, uint32_t seed) { - init(bit_depth, seed); + blue_noise_rng(uint8_t bit_depth, std::unique_ptr rng) { + init(bit_depth, std::move(rng)); } - // currently this uses lowbias32 as the white noise RNG source - // in practice, any white noise RNG source works - // this random noise is used to perturb the error diffusion weights (binary decision) - // as well as to fill in the low bits of the double precision output to eliminate aliasing - static uint32_t hash(uint32_t x) { // lowbias32 - x ^= x >> 16; x *= 0x21f0aaad; - x ^= x >> 15; x *= 0x735a2d97; - x ^= x >> 15; - return x; + // custom copy (clone the underlying RNG) + blue_noise_rng(const blue_noise_rng & other) + : bit_depth(other.bit_depth) + , rng(other.rng ? other.rng->clone() : nullptr) + , states(other.states) {} + + blue_noise_rng & operator=(const blue_noise_rng & other) { + if (this != &other) { + bit_depth = other.bit_depth; + rng = other.rng ? other.rng->clone() : nullptr; + states = other.states; + } + return *this; } - void init(uint8_t depth, uint32_t s) { + blue_noise_rng(blue_noise_rng &&) = default; + blue_noise_rng & operator=(blue_noise_rng &&) = default; + + void init(uint8_t depth, std::unique_ptr source) { bit_depth = std::clamp(depth, 1, 16); - seed = hash(s); + rng = std::move(source); const int n = (1 << bit_depth) - 1; states.resize(n); // at 16-bit depth, this uses 128KB of state @@ -376,9 +403,13 @@ struct blue_noise_rng { reset(); } + void reseed(uint32_t s) { + rng->reseed(s); + reset(); + } + void reset() { const int n = (int)states.size(); - position = 0; // 5 reachable states with distribution 3:3:2:1:1 // established based on empirical testing @@ -390,15 +421,12 @@ struct blue_noise_rng { {-1, -1}, }; for (int i = 0; i < n; i++) { - uint32_t h = hash((uint32_t)i ^ seed) % 10; + uint32_t h = rng->next32() % 10; states[i] = {tbl[h][0], tbl[h][1]}; // random initial state } } - uint16_t next(uint32_t * hash_remainder = nullptr) { - uint32_t h = hash(position ^ seed); - position++; - + uint16_t advance(uint32_t h) { // traverse binary tree, one error diffusion ditherer per population split // thresholding output at any value still produces blue noise uint32_t acc = 0; @@ -416,50 +444,39 @@ struct blue_noise_rng { acc = acc * 2 + out; } - - if (hash_remainder) { - *hash_remainder = h >> bit_depth; // unused bits from random hash - } - return (uint16_t)acc; } - // blue noise in the upper bit_depth bits, white noise hash remainder in the lower bits + uint16_t next() { + uint32_t h = rng->next32(); + return advance(h); + } + + // blue noise in the upper bit_depth bits, white noise in the lower bits // do not use with modulo operator, as it would just produce white noise uint32_t next32() { - uint32_t rem; - uint32_t val = next(&rem); - return (val << (32 - bit_depth)) | rem; + uint32_t h = rng->next32(); + uint32_t val = advance(h); + return (val << (32 - bit_depth)) | (h >> bit_depth); + } + + // blue noise in the upper bits, white noise in the lower bits + uint64_t next64() { + uint64_t r = rng->next64(); + uint32_t lo = (uint32_t)r; + uint32_t h = (uint32_t)(r >> 32); + uint32_t val = advance(h); + uint32_t hi = (val << (32 - bit_depth)) | (h >> bit_depth); + return ((uint64_t)hi << 32) | lo; } // uniform double in [0, 1) with blue noise temporal autocorrelation double nextf() { - uint32_t lo = hash(position ^ ~seed); // white noise low bits - uint32_t hi = next32(); // blue noise high bits - uint64_t combined = ((uint64_t)hi << 32) | lo; + uint64_t combined = next64(); return (combined >> 11) * 0x1.0p-53; } }; -// abstract RNG interface for the dist sampler -struct llama_dist_rng { - virtual ~llama_dist_rng() = default; - - // whether the RNG requires sorted input for proper properties - // this also indicates whether the RNG output itself must be consumed in a coherent order - virtual bool requires_sorted() = 0; - - // for compatibility with std::discrete_distribution - // only used in a disabled branch of llama_sampler_dist_apply - virtual uint32_t rng_min() = 0; - virtual uint32_t rng_max() = 0; - virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()] - - virtual double nextf() = 0; // uniform double in [0, 1) - virtual void reseed(uint32_t s) = 0; - virtual std::unique_ptr clone() const = 0; -}; - // adapter to satisfy UniformRandomBitGenerator for std::discrete_distribution // note: not guaranteed to preserve blue noise properties // this is only used in a disabled branch of llama_sampler_dist_apply, added for compatibility @@ -515,6 +532,55 @@ static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng return (int)(cur_p->size - 1); } +struct llama_dist_rng_lowbias32 : llama_dist_rng { + uint32_t hashed_seed = 0; + uint32_t position = 0; + + llama_dist_rng_lowbias32(uint32_t seed) : hashed_seed(hash(seed)), position(0) {} + + bool requires_sorted() override { return false; } + uint32_t rng_min() override { return 0; } + uint32_t rng_max() override { return UINT32_MAX; } + + static uint32_t hash(uint32_t x) { // lowbias32 + // coefficients from https://github.com/skeeto/hash-prospector/issues/19 + x ^= x >> 16; x *= 0x21f0aaad; + x ^= x >> 15; x *= 0x735a2d97; + x ^= x >> 15; + return x; + } + + uint32_t next() override { + uint32_t val = hash(position ^ hashed_seed); + position++; + return val; + } + + uint32_t next32() override { + return next(); + } + + uint64_t next64() override { + uint64_t lo = hash(position ^ ~hashed_seed); // secondary sequence using opposing seed + uint64_t hi = next(); + return (hi << 32) | lo; + } + + double nextf() override { + uint64_t combined = next64(); + return (combined >> 11) * 0x1.0p-53; + } + + void reseed(uint32_t s) override { + hashed_seed = hash(s); + position = 0; + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } +}; + struct llama_dist_rng_mt19937 : llama_dist_rng { std::mt19937 rng; @@ -524,11 +590,18 @@ struct llama_dist_rng_mt19937 : llama_dist_rng { uint32_t rng_min() override { return std::mt19937::min(); } uint32_t rng_max() override { return std::mt19937::max(); } + uint32_t next() override { return rng(); } - uint32_t next() override { + uint32_t next32() override { return rng(); } + uint64_t next64() override { + uint64_t hi = (uint64_t)rng() << 32; + uint64_t lo = (uint64_t)rng(); + return hi | lo; + } + double nextf() override { std::uniform_real_distribution dist(0.0, 1.0); return dist(rng); @@ -546,15 +619,21 @@ struct llama_dist_rng_mt19937 : llama_dist_rng { struct llama_dist_rng_blue : llama_dist_rng { blue_noise_rng bn_rng; - llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {} + llama_dist_rng_blue(uint32_t seed) + : bn_rng(16, std::make_unique(seed)) {} bool requires_sorted() override { return true; } uint32_t rng_min() override { return 0; } uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; } + uint32_t next() override { return bn_rng.next(); } - uint32_t next() override { - return bn_rng.next(); + uint32_t next32() override { + return bn_rng.next32(); + } + + uint64_t next64() override { + return bn_rng.next64(); } double nextf() override { @@ -562,7 +641,7 @@ struct llama_dist_rng_blue : llama_dist_rng { } void reseed(uint32_t s) override { - bn_rng.init(16, s); + bn_rng.reseed(s); } std::unique_ptr clone() const override {