sampling : blue noise requires tokens to be sorted
This commit is contained in:
parent
d5def78bb0
commit
e829f2904e
|
|
@ -443,9 +443,10 @@ struct blue_noise_rng {
|
|||
struct llama_dist_rng {
|
||||
virtual ~llama_dist_rng() = default;
|
||||
|
||||
virtual bool requires_sorted() = 0; // whether the RNG requires sorted input for proper properties
|
||||
virtual uint32_t rng_min() = 0;
|
||||
virtual uint32_t rng_max() = 0;
|
||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||
virtual void reseed(uint32_t s) = 0;
|
||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||
|
|
@ -468,6 +469,8 @@ struct llama_dist_rng_white : llama_dist_rng {
|
|||
|
||||
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
|
||||
|
||||
bool requires_sorted() override { return false; }
|
||||
|
||||
uint32_t rng_min() override { return std::mt19937::min(); }
|
||||
uint32_t rng_max() override { return std::mt19937::max(); }
|
||||
|
||||
|
|
@ -496,6 +499,8 @@ struct llama_dist_rng_blue : llama_dist_rng {
|
|||
|
||||
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
||||
|
||||
bool requires_sorted() override { return true; }
|
||||
|
||||
uint32_t rng_min() override { return 0; }
|
||||
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
|
||||
|
||||
|
|
@ -1234,6 +1239,11 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|||
return;
|
||||
}
|
||||
|
||||
// sort if required by the RNG (e.g., blue noise needs sorted input for proper temporal properties)
|
||||
if (ctx->rng->requires_sorted() && !cur_p->sorted) {
|
||||
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
|
||||
}
|
||||
|
||||
// max logit for numerical stability
|
||||
float max_l = cur_p->data[0].logit;
|
||||
if (!cur_p->sorted) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue