sampling : implement disabled branch to support blue noise
This commit is contained in:
parent
23b5a5c026
commit
7f433763b6
|
|
@ -448,10 +448,12 @@ struct blue_noise_rng {
|
||||||
struct llama_dist_rng {
|
struct llama_dist_rng {
|
||||||
virtual ~llama_dist_rng() = default;
|
virtual ~llama_dist_rng() = default;
|
||||||
|
|
||||||
virtual bool requires_sorted() = 0; // whether the RNG requires sorted input for proper properties
|
// whether the RNG requires sorted input for proper properties
|
||||||
|
// this also indicates whether the RNG output itself must be consumed in a coherent order
|
||||||
|
virtual bool requires_sorted() = 0;
|
||||||
|
|
||||||
// for compatilibility with std::discrete_distribution
|
// for compatibility with std::discrete_distribution
|
||||||
// nly used in a disabled branch of llama_sampler_dist_apply
|
// only used in a disabled branch of llama_sampler_dist_apply
|
||||||
virtual uint32_t rng_min() = 0;
|
virtual uint32_t rng_min() = 0;
|
||||||
virtual uint32_t rng_max() = 0;
|
virtual uint32_t rng_max() = 0;
|
||||||
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
virtual uint32_t next() = 0; // uniform bits in [rng_min(), rng_max()]
|
||||||
|
|
@ -474,6 +476,48 @@ struct llama_dist_urbg {
|
||||||
result_type operator()() { return rng.next(); }
|
result_type operator()() { return rng.next(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// wrapper to use existing llama_sample_dist for mt19937, otherwise implements CDF walk directly
|
||||||
|
// this is currently only used in a disabled branch of llama_sampler_dist_apply, added for compatibility and potential use by other samplers
|
||||||
|
// flag normalized to skip recomputing the probability sum when probs already sum to 1
|
||||||
|
static int llama_sample_dist_rng(llama_token_data_array * cur_p, llama_dist_rng & rng, bool normalized = false) {
|
||||||
|
if (!rng.requires_sorted()) {
|
||||||
|
llama_dist_urbg urbg{rng};
|
||||||
|
return llama_sample_dist(cur_p, urbg);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!cur_p->sorted) {
|
||||||
|
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
|
||||||
|
}
|
||||||
|
const double rnd = rng.nextf();
|
||||||
|
|
||||||
|
double sum_run = 0.0;
|
||||||
|
|
||||||
|
if (normalized) {
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
sum_run += cur_p->data[i].p;
|
||||||
|
if (sum_run >= rnd) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
double sum_cum = 0.0;
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
sum_cum += cur_p->data[i].p;
|
||||||
|
}
|
||||||
|
|
||||||
|
const double sum_tgt = sum_cum * rnd;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
sum_run += cur_p->data[i].p;
|
||||||
|
if (sum_run >= sum_tgt) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (int)(cur_p->size - 1);
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
|
|
@ -1301,9 +1345,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
||||||
cur_p->data[i].p /= sum_cum;
|
cur_p->data[i].p /= sum_cum;
|
||||||
}
|
}
|
||||||
|
|
||||||
// this implementation is not guaranteed to preserve blue noise properties
|
cur_p->selected = llama_sample_dist_rng(cur_p, *ctx->rng, true);
|
||||||
llama_dist_urbg urbg{*ctx->rng};
|
|
||||||
cur_p->selected = llama_sample_dist(cur_p, urbg);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue