llama : add blue noise rng implementation
This commit is contained in:
parent
e0c93af2a0
commit
e856c8f959
|
|
@ -333,6 +333,81 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
cur_p->size = k;
|
||||
}
|
||||
|
||||
// pseudo-random number generator with ~6db/octave blue noise temporal autocorrelation
|
||||
struct blue_noise_rng {
|
||||
uint8_t bit_depth = 0;
|
||||
uint32_t seed = 0;
|
||||
uint32_t position = 0;
|
||||
|
||||
// binary tree of 1-bit 50% duty cycle blue noise generators
|
||||
std::vector<std::array<int8_t, 2>> states; // {err0, err1} per tree node
|
||||
|
||||
blue_noise_rng() = default;
|
||||
|
||||
blue_noise_rng(uint8_t bit_depth, uint32_t seed) {
|
||||
init(bit_depth, seed);
|
||||
}
|
||||
|
||||
static uint32_t hash(uint32_t x) { // lowbias32
|
||||
x ^= x >> 16; x *= 0x21f0aaad;
|
||||
x ^= x >> 15; x *= 0x735a2d97;
|
||||
x ^= x >> 15;
|
||||
return x;
|
||||
}
|
||||
|
||||
void init(uint8_t depth, uint32_t s) {
|
||||
bit_depth = std::clamp<uint8_t>(depth, 1, 16);
|
||||
seed = hash(s);
|
||||
|
||||
const int n = (1 << bit_depth) - 1;
|
||||
states.resize(n);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
const int n = (int)states.size();
|
||||
position = 0;
|
||||
|
||||
// 5 reachable states with stationary distribution 3:3:2:1:1 (out of 10)
|
||||
static const int8_t tbl[10][2] = {
|
||||
{ 0, 0}, { 0, 0}, { 0, 0},
|
||||
{-1, 0}, {-1, 0}, {-1, 0},
|
||||
{ 0, -1}, { 0, -1},
|
||||
{-2, 0},
|
||||
{-1, -1},
|
||||
};
|
||||
for (int i = 0; i < n; i++) {
|
||||
uint32_t h = hash((uint32_t)i ^ seed) % 10;
|
||||
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
|
||||
}
|
||||
}
|
||||
|
||||
uint16_t next() {
|
||||
uint32_t h = hash(position ^ seed);
|
||||
position++;
|
||||
|
||||
// traverse binary tree root-to-leaf, one error diffusion ditherer per bit
|
||||
uint32_t acc = 0;
|
||||
for (int level = 0; level < bit_depth; level++) {
|
||||
auto & s = states[(1 << level) - 1 + acc]; // heap-style index
|
||||
|
||||
int out = (s[0] >= 0) ? 1 : 0;
|
||||
int8_t qe = s[0] + (int8_t)(out ? -1 : 1); // inverse autocorrelation
|
||||
|
||||
s[0] = s[1]; // step forward
|
||||
s[1] = 0;
|
||||
|
||||
// error diffusion dithering using binary weight perturbation
|
||||
s[(h >> level) & 1 ? 0 : 1] += qe; // forward to t+1 or defer to t+2
|
||||
|
||||
acc = acc * 2 + out;
|
||||
}
|
||||
|
||||
return (uint16_t)acc;
|
||||
}
|
||||
};
|
||||
|
||||
static uint32_t get_rng_seed(uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
// use system clock if std::random_device is not a true RNG
|
||||
|
|
|
|||
Loading…
Reference in New Issue