llama : cleanup and restore alternate code path

This commit is contained in:
Jan Boon 2026-02-04 23:35:22 +00:00
parent 3b4061981b
commit 766d86df29
1 changed files with 38 additions and 10 deletions

View File

@ -214,7 +214,8 @@ static void llama_token_data_array_partial_sort_inplace(llama_token_data_array *
cur_p->sorted = true;
}
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
template<typename RNG>
static int llama_sample_dist(llama_token_data_array * cur_p, RNG & rng) {
// iterator for the probabilities
#ifdef __GNUC__
#pragma GCC diagnostic push
@ -334,6 +335,10 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
}
// pseudo-random number generator with ~6 dB/octave blue noise
// important: the blue noise properties cannot be preserved when
// - the generator is used for multiple purposes simultaneously,
// - multiple next calls are combined to construct a larger value, or
// - integer outputs are reduced with the modulo operator
struct blue_noise_rng {
uint8_t bit_depth = 0;
uint32_t seed = 0;
@ -436,16 +441,38 @@ struct blue_noise_rng {
// abstract random number source: every operation is pure virtual and is
// supplied by a concrete generator derived from this struct
struct llama_dist_rng {
    virtual ~llama_dist_rng() = default;

    // inclusive lower and upper bound of the values returned by next()
    virtual uint32_t rng_min() = 0;
    virtual uint32_t rng_max() = 0;

    // uniform bits in [rng_min(), rng_max()]
    virtual uint32_t next() = 0;

    // uniform double in [0, 1)
    virtual double nextf() = 0;

    // NOTE(review): presumably restarts the generator from the given seed - confirm against the implementations
    virtual void reseed(uint32_t seed) = 0;

    // hands back a separately owned generator
    // NOTE(review): presumably a copy that carries the current state - confirm against the implementations
    virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
};
// thin wrapper exposing a llama_dist_rng through the result_type / min() /
// max() / operator() members of a UniformRandomBitGenerator, so that it can be
// handed to std::discrete_distribution
// note: not guaranteed to preserve blue noise properties
// NOTE(review): min()/max() are non-static members here, while the standard
// spells the requirement as G::min()/G::max(); confirm this builds with every
// standard library the project targets
struct llama_dist_urbg {
    using result_type = uint32_t;

    // wrapped generator, held by reference (must outlive this adapter)
    llama_dist_rng & rng;

    // smallest value operator() may return, forwarded from the wrapped generator
    result_type min() {
        return rng.rng_min();
    }

    // largest value operator() may return, forwarded from the wrapped generator
    result_type max() {
        return rng.rng_max();
    }

    // draws the next value from the wrapped generator
    result_type operator()() {
        return rng.next();
    }
};
struct llama_dist_rng_white : llama_dist_rng {
std::mt19937 rng;
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
uint32_t rng_min() override { return std::mt19937::min(); }
uint32_t rng_max() override { return std::mt19937::max(); }
uint32_t next() override {
return rng();
}
double nextf() override {
std::uniform_real_distribution<double> dist(0.0, 1.0);
return dist(rng);
@ -467,6 +494,13 @@ struct llama_dist_rng_blue : llama_dist_rng {
// constructs the wrapped blue noise generator from the literal 16 and the caller's seed
// NOTE(review): 16 is presumably the bit depth (first field of blue_noise_rng) - confirm against its constructor
llama_dist_rng_blue(uint32_t seed)
    : bn_rng(16, seed) {
}
// lower bound reported for next(): always zero
uint32_t rng_min() override {
    return 0;
}
// upper bound reported for next(): all ones across bn_rng.bit_depth bits,
// i.e. 2^bit_depth - 1
// a shift by 32 or more is undefined for a 32-bit operand, so a bit depth of
// 32 or above saturates to the full 32-bit range instead of shifting
uint32_t rng_max() override {
    const uint32_t bits = bn_rng.bit_depth;
    if (bits >= 32) {
        return ~uint32_t(0);
    }
    return (uint32_t(1) << bits) - 1;
}
// integer output: forwards to the wrapped blue noise generator
uint32_t next() override {
    const uint32_t value = bn_rng.next();
    return value;
}
// floating point output: forwards to the wrapped blue noise generator
double nextf() override {
    const double sample = bn_rng.nextf();
    return sample;
}
@ -1249,15 +1283,9 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
cur_p->data[i].p /= sum_cum;
}
const double rnd = ctx->rng->nextf();
double cum = 0.0;
for (size_t i = 0; i < cur_p->size; ++i) {
cum += cur_p->data[i].p;
if (cum >= rnd) {
cur_p->selected = i;
break;
}
}
// this implementation is not guaranteed to preserve blue noise properties
llama_dist_urbg urbg{*ctx->rng};
cur_p->selected = llama_sample_dist(cur_p, urbg);
#endif
}