sampling : build fix and cleanup

This commit is contained in:
Jan Boon 2026-02-09 05:36:32 +00:00
parent e896007ad1
commit a4858de4e4
1 changed files with 4 additions and 20 deletions

View File

@ -342,12 +342,6 @@ struct llama_dist_rng {
// this also indicates whether the RNG output itself must be consumed in a sequential order
virtual bool requires_sorted() = 0;
// for compatibility with std::discrete_distribution
// only used in a disabled branch of llama_sampler_dist_apply
virtual uint32_t rng_min() = 0;
virtual uint32_t rng_max() = 0;
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
virtual uint32_t next32() = 0; // uniform 32 bits
virtual uint64_t next64() = 0; // uniform 64 bits
virtual double nextf() = 0; // uniform double in [0, 1)
@ -489,9 +483,9 @@ struct llama_dist_urbg {
llama_dist_rng & rng;
result_type min() { return rng.rng_min(); }
result_type max() { return rng.rng_max(); }
result_type operator()() { return rng.next(); }
static constexpr result_type min() { return 0; }
static constexpr result_type max() { return UINT32_MAX; }
result_type operator()() { return rng.next32(); }
};
// wrapper to use existing llama_sample_dist for mt19937, otherwise implements CDF walk directly
@ -543,8 +537,6 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng {
llama_dist_rng_lowbias32(uint32_t seed) : hashed_seed(hash(seed)), position(0) {}
bool requires_sorted() override { return false; }
uint32_t rng_min() override { return 0; }
uint32_t rng_max() override { return UINT32_MAX; }
static uint32_t hash(uint32_t x) { // lowbias32
// coefficients from https://github.com/skeeto/hash-prospector/issues/19
@ -554,7 +546,7 @@ struct llama_dist_rng_lowbias32 : llama_dist_rng {
return x;
}
uint32_t next() override {
uint32_t next() {
uint32_t val = hash(position ^ hashed_seed);
position++;
return val;
@ -597,10 +589,6 @@ struct llama_dist_rng_mt19937 : llama_dist_rng {
bool requires_sorted() override { return false; }
uint32_t rng_min() override { return std::mt19937::min(); }
uint32_t rng_max() override { return std::mt19937::max(); }
uint32_t next() override { return rng(); }
uint32_t next32() override {
return rng();
}
@ -638,10 +626,6 @@ struct llama_dist_rng_blue : llama_dist_rng {
bool requires_sorted() override { return true; }
uint32_t rng_min() override { return 0; }
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
uint32_t next() override { return bn_rng.next(); }
uint32_t next32() override {
return bn_rng.next32();
}