diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index 0a74f2d26f..f06c76077b 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -342,12 +342,6 @@ struct llama_dist_rng { // this also indicates whether the RNG output itself must be consumed in a sequential order virtual bool requires_sorted() = 0; - // for compatibility with std::discrete_distribution - // only used in a disabled branch of llama_sampler_dist_apply - virtual uint32_t rng_min() = 0; - virtual uint32_t rng_max() = 0; - virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()] - virtual uint32_t next32() = 0; // uniform 32 bits virtual uint64_t next64() = 0; // uniform 64 bits virtual double nextf() = 0; // uniform double in [0, 1) @@ -489,9 +483,9 @@ struct llama_dist_urbg { llama_dist_rng & rng; - result_type min() { return rng.rng_min(); } - result_type max() { return rng.rng_max(); } - result_type operator()() { return rng.next(); } + static constexpr result_type min() { return 0; } + static constexpr result_type max() { return UINT32_MAX; } + result_type operator()() { return rng.next32(); } }; // wrapper to use existing llama_sample_dist for mt19937, otherwise implements CDF walk directly @@ -543,8 +537,6 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng { llama_dist_rng_lowbias32(uint32_t seed) : hashed_seed(hash(seed)), position(0) {} bool requires_sorted() override { return false; } - uint32_t rng_min() override { return 0; } - uint32_t rng_max() override { return UINT32_MAX; } static uint32_t hash(uint32_t x) { // lowbias32 // coefficients from https://github.com/skeeto/hash-prospector/issues/19 @@ -554,7 +546,7 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng { return x; } - uint32_t next() override { + uint32_t next() { uint32_t val = hash(position ^ hashed_seed); position++; return val; @@ -597,10 +589,6 @@ struct llama_dist_rng_mt19937 : llama_dist_rng { bool requires_sorted() override { return false; } - uint32_t rng_min() override { return std::mt19937::min(); } - uint32_t rng_max() override { return std::mt19937::max(); } - uint32_t next() override { return rng(); } - uint32_t next32() override { return rng(); } @@ -638,10 +626,6 @@ struct llama_dist_rng_blue : llama_dist_rng { bool requires_sorted() override { return true; } - uint32_t rng_min() override { return 0; } - uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; } - uint32_t next() override { return bn_rng.next(); } - uint32_t next32() override { return bn_rng.next32(); }