sampling : blue noise requires tokens to be sorted
This commit is contained in:
parent
d5def78bb0
commit
e829f2904e
|
|
@ -443,9 +443,10 @@ struct blue_noise_rng {
|
||||||
struct llama_dist_rng {
|
struct llama_dist_rng {
|
||||||
virtual ~llama_dist_rng() = default;
|
virtual ~llama_dist_rng() = default;
|
||||||
|
|
||||||
|
virtual bool requires_sorted() = 0; // true if the token array must be sorted before sampling for this RNG to keep its intended properties; the dist sampler sorts unsorted input when this returns true
|
||||||
virtual uint32_t rng_min() = 0;
|
virtual uint32_t rng_min() = 0;
|
||||||
virtual uint32_t rng_max() = 0;
|
virtual uint32_t rng_max() = 0;
|
||||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||||
virtual void reseed(uint32_t s) = 0;
|
virtual void reseed(uint32_t s) = 0;
|
||||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||||
|
|
@ -468,6 +469,8 @@ struct llama_dist_rng_white : llama_dist_rng {
|
||||||
|
|
||||||
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
|
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
|
||||||
|
|
||||||
|
bool requires_sorted() override { return false; } // this RNG does not ask for sorted input
|
||||||
|
|
||||||
uint32_t rng_min() override { return std::mt19937::min(); }
|
uint32_t rng_min() override { return std::mt19937::min(); }
|
||||||
uint32_t rng_max() override { return std::mt19937::max(); }
|
uint32_t rng_max() override { return std::mt19937::max(); }
|
||||||
|
|
||||||
|
|
@ -496,6 +499,8 @@ struct llama_dist_rng_blue : llama_dist_rng {
|
||||||
|
|
||||||
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
||||||
|
|
||||||
|
bool requires_sorted() override { return true; } // blue noise needs sorted input for proper temporal properties
|
||||||
|
|
||||||
uint32_t rng_min() override { return 0; }
|
uint32_t rng_min() override { return 0; }
|
||||||
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
|
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
|
||||||
|
|
||||||
|
|
@ -1234,6 +1239,11 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sort if required by the RNG (e.g., blue noise needs sorted input for proper temporal properties)
|
||||||
|
if (ctx->rng->requires_sorted() && !cur_p->sorted) {
|
||||||
|
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
|
||||||
|
}
|
||||||
|
|
||||||
// max logit for numerical stability
|
// max logit for numerical stability
|
||||||
float max_l = cur_p->data[0].logit;
|
float max_l = cur_p->data[0].logit;
|
||||||
if (!cur_p->sorted) {
|
if (!cur_p->sorted) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue