sampling : test against previous implementation

This commit is contained in:
Jan Boon 2026-02-09 04:17:07 +00:00
parent f3acd240d6
commit 75cb3e8f2e
1 changed files with 22 additions and 4 deletions

View File

@ -352,6 +352,7 @@ struct llama_dist_rng {
virtual uint64_t next64() = 0; // uniform 64 bits
virtual double nextf() = 0; // uniform double in [0, 1)
virtual void reseed(uint32_t s) = 0;
virtual void reset() = 0; // reset to post-seed state
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
};
@ -400,15 +401,15 @@ struct blue_noise_rng {
const int n = (1 << bit_depth) - 1;
states.resize(n); // at 16-bit depth, this uses 128KB of state
reset();
reset_states();
}
void reseed(uint32_t s) {
rng->reseed(s);
reset();
reset_states();
}
void reset() {
void reset_states() {
const int n = (int)states.size();
// 5 reachable states with distribution 3:3:2:1:1
@ -424,6 +425,8 @@ struct blue_noise_rng {
uint32_t h = rng->next32() % 10;
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
}
rng->reset(); // reset position so generation starts from 0
}
uint16_t advance(uint32_t h) {
@ -573,15 +576,20 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng {
position = 0;
}
void reset() override {
position = 0;
}
std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_lowbias32>(*this);
}
};
struct llama_dist_rng_mt19937 : llama_dist_rng {
uint32_t seed;
std::mt19937 rng;
llama_dist_rng_mt19937(uint32_t seed) : rng(seed) {}
llama_dist_rng_mt19937(uint32_t seed) : seed(seed), rng(seed) {}
bool requires_sorted() override { return false; }
@ -605,9 +613,14 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
}
void reseed(uint32_t s) override {
seed = s;
rng.seed(s);
}
void reset() override {
rng.seed(seed);
}
std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_mt19937>(*this);
}
@ -641,6 +654,11 @@ struct llama_dist_rng_blue : llama_dist_rng {
bn_rng.reseed(s);
}
void reset() override {
bn_rng.rng->reset();
bn_rng.reset_states();
}
std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_blue>(*this);
}