sampling : implement disabled branch to support blue noise
This commit is contained in:
parent
23b5a5c026
commit
7f433763b6
|
|
@ -448,10 +448,12 @@ struct blue_noise_rng {
|
|||
struct llama_dist_rng {
|
||||
virtual ~llama_dist_rng() = default;
|
||||
|
||||
virtual bool requires_sorted() = 0; // whether the RNG requires sorted input for proper properties
|
||||
// whether the RNG requires sorted input for proper properties
|
||||
// this also indicates whether the RNG output itself must be consumed in a coherent order
|
||||
virtual bool requires_sorted() = 0;
|
||||
|
||||
// for compatilibility with std::discrete_distribution
|
||||
// nly used in a disabled branch of llama_sampler_dist_apply
|
||||
// for compatibility with std::discrete_distribution
|
||||
// only used in a disabled branch of llama_sampler_dist_apply
|
||||
virtual uint32_t rng_min() = 0;
|
||||
virtual uint32_t rng_max() = 0;
|
||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||
|
|
@ -474,6 +476,48 @@ struct llama_dist_urbg {
|
|||
result_type operator()() { return rng.next(); }
|
||||
};
|
||||
|
||||
// wrapper to use existing llama_sample_dist for mt19937, otherwise implements CDF walk directly
|
||||
// this is currently only used in a disabled branch of llama_sampler_dist_apply, added for compatibility and potential use by other samplers
|
||||
// flag normalized to skip recomputing the probability sum when probs already sum to 1
|
||||
static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng & rng, bool normalized = false) {
|
||||
if (!rng.requires_sorted()) {
|
||||
llama_dist_urbg urbg{rng};
|
||||
return llama_sample_dist(cur_p, urbg);
|
||||
}
|
||||
|
||||
if (!cur_p->sorted) {
|
||||
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
|
||||
}
|
||||
const double rnd = rng.nextf();
|
||||
|
||||
double sum_run = 0.0;
|
||||
|
||||
if (normalized) {
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
sum_run += cur_p->data[i].p;
|
||||
if (sum_run >= rnd) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
double sum_cum = 0.0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
sum_cum += cur_p->data[i].p;
|
||||
}
|
||||
|
||||
const double sum_tgt = sum_cum * rnd;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
sum_run += cur_p->data[i].p;
|
||||
if (sum_run >= sum_tgt) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (int)(cur_p->size - 1);
|
||||
}
|
||||
|
||||
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
||||
std::mt19937 rng;
|
||||
|
||||
|
|
@ -1301,9 +1345,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|||
cur_p->data[i].p /= sum_cum;
|
||||
}
|
||||
|
||||
// this implementation is not guaranteed to preserve blue noise properties
|
||||
llama_dist_urbg urbg{*ctx->rng};
|
||||
cur_p->selected = llama_sample_dist(cur_p, urbg);
|
||||
cur_p->selected = llama_sample_dist_rng(cur_p, *ctx->rng, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue