diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5dd094ce7a..7c83095582 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -333,13 +333,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } -// pseudo-random number generator with ~6db/octave blue noise temporal autocorrelation +// pseudo-random number generator with ~6db/octave blue noise struct blue_noise_rng { uint8_t bit_depth = 0; uint32_t seed = 0; uint32_t position = 0; - // binary tree of 1-bit 50% duty cycle blue noise generators + // binary tree of 1-bit 50% duty cycle error diffusion dithering blue noise generators std::vector> states; // {err0, err1} per tree node blue_noise_rng() = default; @@ -383,11 +383,12 @@ struct blue_noise_rng { } } - uint16_t next() { + uint16_t next(uint32_t * hash_remainder = nullptr) { uint32_t h = hash(position ^ seed); position++; - // traverse binary tree root-to-leaf, one error diffusion ditherer per bit + // traverse binary tree, one error diffusion ditherer per population split + // thresholding output at any value still produces blue noise uint32_t acc = 0; for (int level = 0; level < bit_depth; level++) { auto & s = states[(1 << level) - 1 + acc]; // heap-style index @@ -404,8 +405,31 @@ struct blue_noise_rng { acc = acc * 2 + out; } + if (hash_remainder) { + *hash_remainder = h >> bit_depth; // unused bits from random hash + } + return (uint16_t)acc; } + + // blue noise in the upper bit_depth bits, white noise hash remainder in the lower bits + // do not use with modulo operator, as it would just produce white noise + uint32_t next32() { + uint32_t rem; + uint32_t val = next(&rem); + return (val << (32 - bit_depth)) | rem; + } + + // uniform double in [0, 1) with blue noise temporal autocorrelation + double nextf() { + double res = 0.0; + res += hash(position ^ ~seed); // fill low bits with white noise + res *= 1.0 / 4294967296.0; + res += next32(); + res *= 1.0 / 4294967296.0; + if (res >= 1.0) res = std::nextafter(1.0, 0.0); + return res; + } }; static uint32_t get_rng_seed(uint32_t seed) {