diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index df30a613d5..9f5cba09c2 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -352,6 +352,7 @@ struct llama_dist_rng { virtual uint64_t next64() = 0; // uniform 64 bits virtual double nextf() = 0; // uniform double in [0, 1) virtual void reseed(uint32_t s) = 0; + virtual void reset() = 0; // reset to post-seed state virtual std::unique_ptr<llama_dist_rng> clone() const = 0; }; @@ -400,15 +401,15 @@ struct blue_noise_rng { const int n = (1 << bit_depth) - 1; states.resize(n); // at 16-bit depth, this uses 128KB of state - reset(); + reset_states(); } void reseed(uint32_t s) { rng->reseed(s); - reset(); + reset_states(); } - void reset() { + void reset_states() { const int n = (int)states.size(); // 5 reachable states with distribution 3:3:2:1:1 @@ -424,6 +425,8 @@ struct blue_noise_rng { uint32_t h = rng->next32() % 10; states[i] = {tbl[h][0], tbl[h][1]}; // random initial state } + + rng->reset(); // reset position so generation starts from 0 } uint16_t advance(uint32_t h) { @@ -573,15 +576,20 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng { position = 0; } + void reset() override { + position = 0; + } + std::unique_ptr<llama_dist_rng> clone() const override { return std::make_unique<llama_dist_rng_lowbias32>(*this); } }; struct llama_dist_rng_mt19937 : llama_dist_rng { + uint32_t seed; std::mt19937 rng; - llama_dist_rng_mt19937(uint32_t seed) : rng(seed) {} + llama_dist_rng_mt19937(uint32_t seed) : seed(seed), rng(seed) {} bool requires_sorted() override { return false; } @@ -605,9 +613,14 @@ struct llama_dist_rng_mt19937 : llama_dist_rng { } void reseed(uint32_t s) override { + seed = s; rng.seed(s); } + void reset() override { + rng.seed(seed); + } + std::unique_ptr<llama_dist_rng> clone() const override { return std::make_unique<llama_dist_rng_mt19937>(*this); } @@ -641,6 +654,11 @@ struct llama_dist_rng_blue : llama_dist_rng { bn_rng.reseed(s); } + void reset() override { + bn_rng.rng->reset(); + bn_rng.reset_states(); + } + std::unique_ptr<llama_dist_rng> clone() const override { return 
std::make_unique<llama_dist_rng_blue>(*this); }