sampling : test against previous implementation
This commit is contained in:
parent
f3acd240d6
commit
75cb3e8f2e
|
|
@ -352,6 +352,7 @@ struct llama_dist_rng {
|
||||||
virtual uint64_t next64() = 0; // uniform 64 bits
|
virtual uint64_t next64() = 0; // uniform 64 bits
|
||||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||||
virtual void reseed(uint32_t s) = 0;
|
virtual void reseed(uint32_t s) = 0;
|
||||||
|
virtual void reset() = 0; // reset to post-seed state
|
||||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -400,15 +401,15 @@ struct blue_noise_rng {
|
||||||
const int n = (1 << bit_depth) - 1;
|
const int n = (1 << bit_depth) - 1;
|
||||||
states.resize(n); // at 16-bit depth, this uses 128KB of state
|
states.resize(n); // at 16-bit depth, this uses 128KB of state
|
||||||
|
|
||||||
reset();
|
reset_states();
|
||||||
}
|
}
|
||||||
|
|
||||||
void reseed(uint32_t s) {
|
void reseed(uint32_t s) {
|
||||||
rng->reseed(s);
|
rng->reseed(s);
|
||||||
reset();
|
reset_states();
|
||||||
}
|
}
|
||||||
|
|
||||||
void reset() {
|
void reset_states() {
|
||||||
const int n = (int)states.size();
|
const int n = (int)states.size();
|
||||||
|
|
||||||
// 5 reachable states with distribution 3:3:2:1:1
|
// 5 reachable states with distribution 3:3:2:1:1
|
||||||
|
|
@ -424,6 +425,8 @@ struct blue_noise_rng {
|
||||||
uint32_t h = rng->next32() % 10;
|
uint32_t h = rng->next32() % 10;
|
||||||
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
|
states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rng->reset(); // reset position so generation starts from 0
|
||||||
}
|
}
|
||||||
|
|
||||||
uint16_t advance(uint32_t h) {
|
uint16_t advance(uint32_t h) {
|
||||||
|
|
@ -573,15 +576,20 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng {
|
||||||
position = 0;
|
position = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void reset() override {
|
||||||
|
position = 0;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||||
return std::make_unique<llama_dist_rng_lowbias32>(*this);
|
return std::make_unique<llama_dist_rng_lowbias32>(*this);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
||||||
|
uint32_t seed;
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
llama_dist_rng_mt19937(uint32_t seed) : rng(seed) {}
|
llama_dist_rng_mt19937(uint32_t seed) : seed(seed), rng(seed) {}
|
||||||
|
|
||||||
bool requires_sorted() override { return false; }
|
bool requires_sorted() override { return false; }
|
||||||
|
|
||||||
|
|
@ -605,9 +613,14 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
|
||||||
}
|
}
|
||||||
|
|
||||||
void reseed(uint32_t s) override {
|
void reseed(uint32_t s) override {
|
||||||
|
seed = s;
|
||||||
rng.seed(s);
|
rng.seed(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void reset() override {
|
||||||
|
rng.seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||||
return std::make_unique<llama_dist_rng_mt19937>(*this);
|
return std::make_unique<llama_dist_rng_mt19937>(*this);
|
||||||
}
|
}
|
||||||
|
|
@ -641,6 +654,11 @@ struct llama_dist_rng_blue : llama_dist_rng {
|
||||||
bn_rng.reseed(s);
|
bn_rng.reseed(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void reset() override {
|
||||||
|
bn_rng.rng->reset();
|
||||||
|
bn_rng.reset_states();
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||||
return std::make_unique<llama_dist_rng_blue>(*this);
|
return std::make_unique<llama_dist_rng_blue>(*this);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue