sampling : test against previous implementation

This commit is contained in:
Jan Boon 2026-02-09 04:17:07 +00:00
parent f3acd240d6
commit 75cb3e8f2e
1 changed files with 22 additions and 4 deletions

View File

@ -352,6 +352,7 @@ struct llama_dist_rng {
virtual uint64_t next64() = 0; // uniform 64 bits virtual uint64_t next64() = 0; // uniform 64 bits
virtual double nextf() = 0; // uniform double in [0, 1) virtual double nextf() = 0; // uniform double in [0, 1)
virtual void reseed(uint32_t s) = 0; virtual void reseed(uint32_t s) = 0;
virtual void reset() = 0; // reset to post-seed state
virtual std::unique_ptr<llama_dist_rng> clone() const = 0; virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
}; };
@ -400,15 +401,15 @@ struct blue_noise_rng {
const int n = (1 << bit_depth) - 1; const int n = (1 << bit_depth) - 1;
states.resize(n); // at 16-bit depth, this uses 128KB of state states.resize(n); // at 16-bit depth, this uses 128KB of state
reset(); reset_states();
} }
void reseed(uint32_t s) { void reseed(uint32_t s) {
rng->reseed(s); rng->reseed(s);
reset(); reset_states();
} }
void reset() { void reset_states() {
const int n = (int)states.size(); const int n = (int)states.size();
// 5 reachable states with distribution 3:3:2:1:1 // 5 reachable states with distribution 3:3:2:1:1
@ -424,6 +425,8 @@ struct blue_noise_rng {
uint32_t h = rng->next32() % 10; uint32_t h = rng->next32() % 10;
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
} }
rng->reset(); // reset position so generation starts from 0
} }
uint16_t advance(uint32_t h) { uint16_t advance(uint32_t h) {
@ -573,15 +576,20 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng {
position = 0; position = 0;
} }
void reset() override {
position = 0;
}
std::unique_ptr<llama_dist_rng> clone() const override { std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_lowbias32>(*this); return std::make_unique<llama_dist_rng_lowbias32>(*this);
} }
}; };
struct llama_dist_rng_mt19937 : llama_dist_rng { struct llama_dist_rng_mt19937 : llama_dist_rng {
uint32_t seed;
std::mt19937 rng; std::mt19937 rng;
llama_dist_rng_mt19937(uint32_t seed) : rng(seed) {} llama_dist_rng_mt19937(uint32_t seed) : seed(seed), rng(seed) {}
bool requires_sorted() override { return false; } bool requires_sorted() override { return false; }
@ -605,9 +613,14 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
} }
void reseed(uint32_t s) override { void reseed(uint32_t s) override {
seed = s;
rng.seed(s); rng.seed(s);
} }
void reset() override {
rng.seed(seed);
}
std::unique_ptr<llama_dist_rng> clone() const override { std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_mt19937>(*this); return std::make_unique<llama_dist_rng_mt19937>(*this);
} }
@ -641,6 +654,11 @@ struct llama_dist_rng_blue : llama_dist_rng {
bn_rng.reseed(s); bn_rng.reseed(s);
} }
void reset() override {
bn_rng.rng->reset();
bn_rng.reset_states();
}
std::unique_ptr<llama_dist_rng> clone() const override { std::unique_ptr<llama_dist_rng> clone() const override {
return std::make_unique<llama_dist_rng_blue>(*this); return std::make_unique<llama_dist_rng_blue>(*this);
} }