diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 9b651c816d..0e7d4cb178 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -443,9 +443,10 @@ struct blue_noise_rng { struct llama_dist_rng { virtual ~llama_dist_rng() = default; + virtual bool requires_sorted() = 0; // whether the RNG requires sorted input to retain its intended properties (e.g., the temporal properties of blue noise) virtual uint32_t rng_min() = 0; virtual uint32_t rng_max() = 0; - virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()] + virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()] virtual double nextf() = 0; // uniform double in [0, 1) virtual void reseed(uint32_t s) = 0; virtual std::unique_ptr clone() const = 0; @@ -468,6 +469,8 @@ struct llama_dist_rng_white : llama_dist_rng { llama_dist_rng_white(uint32_t seed) : rng(seed) {} + bool requires_sorted() override { return false; } + uint32_t rng_min() override { return std::mt19937::min(); } uint32_t rng_max() override { return std::mt19937::max(); } @@ -496,6 +499,8 @@ struct llama_dist_rng_blue : llama_dist_rng { llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {} + bool requires_sorted() override { return true; } + uint32_t rng_min() override { return 0; } uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; } @@ -1234,6 +1239,11 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da return; } + // sort if required by the RNG (e.g., blue noise needs sorted input for proper temporal properties) + if (ctx->rng->requires_sorted() && !cur_p->sorted) { + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); + } + // max logit for numerical stability float max_l = cur_p->data[0].logit; if (!cur_p->sorted) {