llama : cleanup and restore alternate code path
This commit is contained in:
parent
3b4061981b
commit
766d86df29
|
|
@ -214,7 +214,8 @@ static void llama_token_data_array_partial_sort_inplace(llama_token_data_array *
|
|||
cur_p->sorted = true;
|
||||
}
|
||||
|
||||
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
||||
template<typename RNG>
|
||||
static int llama_sample_dist(llama_token_data_array * cur_p, RNG & rng) {
|
||||
// iterator for the probabilities
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
|
|
@ -334,6 +335,10 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
}
|
||||
|
||||
// pseudo-random number generator with ~6 dB/octave blue noise
|
||||
// important: blue noise properties cannot be preserved when
|
||||
// the generator is used for multiple purposes simultaneously
|
||||
// nor when multiple next() calls are combined to construct a larger value
|
||||
// nor when integer outputs are used with the modulo operator
|
||||
struct blue_noise_rng {
|
||||
uint8_t bit_depth = 0;
|
||||
uint32_t seed = 0;
|
||||
|
|
@ -436,16 +441,38 @@ struct blue_noise_rng {
|
|||
// abstract random source interface; concrete generators derive from it and
// are accessed through a pointer/reference to this base type
// it exposes both raw integer output (next) and a unit-interval double (nextf)
struct llama_dist_rng {
|
||||
// virtual so that derived generators can be destroyed through a base pointer
virtual ~llama_dist_rng() = default;
|
||||
|
||||
// inclusive lower bound of the values returned by next()
virtual uint32_t rng_min() = 0;
|
||||
// inclusive upper bound of the values returned by next()
virtual uint32_t rng_max() = 0;
|
||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||
// presumably restarts the generator from seed s - confirm in the implementations
virtual void reseed(uint32_t s) = 0;
|
||||
// returns a newly allocated generator owned by the caller
// NOTE(review): presumably a copy carrying the current state - confirm in the implementations
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||
};
|
||||
|
||||
// adapter to satisfy UniformRandomBitGenerator for std::discrete_distribution
// every call is forwarded to the wrapped llama_dist_rng
|
||||
// note: not guaranteed to preserve blue noise properties
// NOTE(review): the UniformRandomBitGenerator named requirement expects static
// min()/max(); here they are instance members because the bounds come from
// virtual calls - confirm that every targeted standard library accepts this
|
||||
struct llama_dist_urbg {
|
||||
// integer type produced by operator()
using result_type = uint32_t;
|
||||
|
||||
// non-owning reference: the wrapped generator must outlive this adapter
llama_dist_rng & rng;
|
||||
|
||||
// smallest value operator() can return, as reported by the wrapped generator
result_type min() { return rng.rng_min(); }
|
||||
// largest value operator() can return, as reported by the wrapped generator
result_type max() { return rng.rng_max(); }
|
||||
// draws the next raw value from the wrapped generator
result_type operator()() { return rng.next(); }
|
||||
};
|
||||
|
||||
struct llama_dist_rng_white : llama_dist_rng {
|
||||
std::mt19937 rng;
|
||||
|
||||
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
|
||||
|
||||
uint32_t rng_min() override { return std::mt19937::min(); }
|
||||
uint32_t rng_max() override { return std::mt19937::max(); }
|
||||
|
||||
uint32_t next() override {
|
||||
return rng();
|
||||
}
|
||||
|
||||
double nextf() override {
|
||||
std::uniform_real_distribution<double> dist(0.0, 1.0);
|
||||
return dist(rng);
|
||||
|
|
@ -467,6 +494,13 @@ struct llama_dist_rng_blue : llama_dist_rng {
|
|||
|
||||
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
||||
|
||||
uint32_t rng_min() override { return 0; }
|
||||
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
|
||||
|
||||
uint32_t next() override {
|
||||
return bn_rng.next();
|
||||
}
|
||||
|
||||
double nextf() override {
|
||||
return bn_rng.nextf();
|
||||
}
|
||||
|
|
@ -1249,15 +1283,9 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|||
cur_p->data[i].p /= sum_cum;
|
||||
}
|
||||
|
||||
const double rnd = ctx->rng->nextf();
|
||||
double cum = 0.0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cum += cur_p->data[i].p;
|
||||
if (cum >= rnd) {
|
||||
cur_p->selected = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// this implementation is not guaranteed to preserve blue noise properties
|
||||
llama_dist_urbg urbg{*ctx->rng};
|
||||
cur_p->selected = llama_sample_dist(cur_p, urbg);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue