sampling : test against previous implementation
This commit is contained in:
parent
f3acd240d6
commit
75cb3e8f2e
|
|
@ -352,6 +352,7 @@ struct llama_dist_rng {
|
|||
virtual uint64_t next64() = 0; // uniform 64 bits
|
||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||
virtual void reseed(uint32_t s) = 0;
|
||||
virtual void reset() = 0; // reset to post-seed state
|
||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||
};
|
||||
|
||||
|
|
@ -400,15 +401,15 @@ struct blue_noise_rng {
|
|||
const int n = (1 << bit_depth) - 1;
|
||||
states.resize(n); // at 16-bit depth, this uses 128KB of state
|
||||
|
||||
reset();
|
||||
reset_states();
|
||||
}
|
||||
|
||||
void reseed(uint32_t s) {
|
||||
rng->reseed(s);
|
||||
reset();
|
||||
reset_states();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
void reset_states() {
|
||||
const int n = (int)states.size();
|
||||
|
||||
// 5 reachable states with distribution 3:3:2:1:1
|
||||
|
|
@ -424,6 +425,8 @@ struct blue_noise_rng {
|
|||
uint32_t h = rng->next32() % 10;
|
||||
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
|
||||
}
|
||||
|
||||
rng->reset(); // reset position so generation starts from 0
|
||||
}
|
||||
|
||||
uint16_t advance(uint32_t h) {
|
||||
|
|
@ -573,15 +576,20 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng {
|
|||
position = 0;
|
||||
}
|
||||
|
||||
void reset() override {
|
||||
position = 0;
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
return std::make_unique<llama_dist_rng_lowbias32>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
||||
uint32_t seed;
|
||||
std::mt19937 rng;
|
||||
|
||||
llama_dist_rng_mt19937(uint32_t seed) : rng(seed) {}
|
||||
llama_dist_rng_mt19937(uint32_t seed) : seed(seed), rng(seed) {}
|
||||
|
||||
bool requires_sorted() override { return false; }
|
||||
|
||||
|
|
@ -605,9 +613,14 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
|
|||
}
|
||||
|
||||
void reseed(uint32_t s) override {
|
||||
seed = s;
|
||||
rng.seed(s);
|
||||
}
|
||||
|
||||
void reset() override {
|
||||
rng.seed(seed);
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
return std::make_unique<llama_dist_rng_mt19937>(*this);
|
||||
}
|
||||
|
|
@ -641,6 +654,11 @@ struct llama_dist_rng_blue : llama_dist_rng {
|
|||
bn_rng.reseed(s);
|
||||
}
|
||||
|
||||
void reset() override {
|
||||
bn_rng.rng->reset();
|
||||
bn_rng.reset_states();
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
return std::make_unique<llama_dist_rng_blue>(*this);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue