|
|
|
|
@ -214,7 +214,8 @@ static void llama_token_data_array_partial_sort_inplace(llama_token_data_array *
|
|
|
|
|
cur_p->sorted = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
|
|
|
|
template<typename RNG>
|
|
|
|
|
static int llama_sample_dist(llama_token_data_array * cur_p, RNG & rng) {
|
|
|
|
|
// iterator for the probabilities
|
|
|
|
|
#ifdef __GNUC__
|
|
|
|
|
#pragma GCC diagnostic push
|
|
|
|
|
@ -333,6 +334,201 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|
|
|
|
cur_p->size = k;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generative error diffusion for sequential blue noise
// pseudo-random number generator with ~6db/octave blue noise
// this generator produces a uniform distribution
// important: blue noise properties cannot be preserved when
//   - the generator is used for multiple purposes simultaneously
//   - multiple next calls are used to construct a larger value
//   - integer outputs are used with the modulo operator
//
// a default-constructed generator has bit_depth == 0 and an empty tree: next() then always
// returns 0 and next32()/nextf() are built from the hash output only; call init() to get blue noise
struct blue_noise_rng {
    uint8_t  bit_depth = 0; // blue noise bits produced per sample, init() clamps it to [1, 16]
    uint32_t seed      = 0; // hashed seed, xor-ed into every hash() input
    uint32_t position  = 0; // number of samples drawn since the last reset()

    // binary tree of 1-bit 50% duty cycle error diffusion dithering blue noise generators
    std::vector<std::array<int8_t, 2>> states; // {err0, err1} per tree node

    blue_noise_rng() = default;

    blue_noise_rng(uint8_t bit_depth, uint32_t seed) {
        init(bit_depth, seed);
    }

    // currently this uses lowbias32 as the white noise RNG source
    // in practice, any white noise RNG source works
    // this random noise is used to perturb the error diffusion weights (binary decision)
    // as well as to fill in the low bits of the double precision output to eliminate aliasing
    static uint32_t hash(uint32_t x) { // lowbias32
        x ^= x >> 16; x *= 0x21f0aaad;
        x ^= x >> 15; x *= 0x735a2d97;
        x ^= x >> 15;
        return x;
    }

    // (re)configure the generator: clamp the depth to [1, 16], hash the seed,
    // size the tree to 2^bit_depth - 1 nodes and reset it to its initial state
    void init(uint8_t depth, uint32_t s) {
        bit_depth = std::clamp<uint8_t>(depth, 1, 16);
        seed      = hash(s);

        const size_t n_nodes = (size_t(1) << bit_depth) - 1;
        states.resize(n_nodes); // at 16-bit depth, this uses 128KB of state

        reset();
    }

    // rewind to position 0 and re-derive every node's initial error state from the seed,
    // so the same (bit_depth, seed) pair always replays the same sequence
    void reset() {
        position = 0;

        // 5 reachable states with distribution 3:3:2:1:1
        // established based on empirical testing
        static constexpr int8_t tbl[10][2] = {
            { 0,  0}, { 0,  0}, { 0,  0},
            {-1,  0}, {-1,  0}, {-1,  0},
            { 0, -1}, { 0, -1},
            {-2,  0},
            {-1, -1},
        };

        for (size_t i = 0; i < states.size(); i++) {
            const uint32_t h = hash((uint32_t) i ^ seed) % 10;
            states[i] = {tbl[h][0], tbl[h][1]}; // random initial state
        }
    }

    // draw the next bit_depth-bit sample (most significant bit decided first)
    // if hash_remainder is non-null it receives the bits of this step's hash above the low bit_depth bits
    uint16_t next(uint32_t * hash_remainder = nullptr) {
        const uint32_t h = hash(position ^ seed);
        position++;

        // traverse binary tree, one error diffusion ditherer per population split
        // thresholding output at any value still produces blue noise
        uint32_t acc = 0;
        for (int level = 0; level < bit_depth; level++) {
            auto & s = states[(1u << level) - 1 + acc]; // heap-style index

            const int    out = (s[0] >= 0) ? 1 : 0;
            const int8_t qe  = s[0] + (int8_t)(out ? -1 : 1); // inverse autocorrelation

            s[0] = s[1]; // step forward
            s[1] = 0;

            // error diffusion dithering using binary weight perturbation
            s[((h >> level) & 1) ? 0 : 1] += qe; // forward to t+1 or defer to t+2

            acc = acc * 2 + out;
        }

        if (hash_remainder) {
            *hash_remainder = h >> bit_depth; // unused bits from random hash
        }

        return (uint16_t) acc;
    }

    // blue noise in the upper bit_depth bits, white noise hash remainder in the lower bits
    // do not use with modulo operator, as it would just produce white noise
    uint32_t next32() {
        uint32_t rem = 0;
        const uint32_t val = next(&rem);

        if (bit_depth == 0) {
            // generator was never init()-ed: there are no blue noise bits, and shifting a
            // 32-bit value by 32 would be undefined behavior - return the hash bits alone
            return rem;
        }

        return (val << (32 - bit_depth)) | rem;
    }

    // uniform double in [0, 1) with blue noise temporal autocorrelation
    double nextf() {
        // the low-bit hash is taken before next32() advances position, and uses ~seed so it
        // differs from the hash consumed inside next()
        double res = 0.0;
        res += hash(position ^ ~seed); // fill low bits with white noise
        res *= 1.0 / 4294967296.0;
        res += next32();
        res *= 1.0 / 4294967296.0;

        // rounding to double precision may reach 1.0 - keep the interval half-open
        if (res >= 1.0) res = std::nextafter(1.0, 0.0);
        return res;
    }
};
|
|
|
|
|
|
|
|
|
|
// polymorphic random source used by the dist sampler
// concrete generators derive from this and are owned through std::unique_ptr<llama_dist_rng>
struct llama_dist_rng {
    virtual ~llama_dist_rng() = default;

    // whether the RNG requires sorted input for proper properties
    virtual bool requires_sorted() = 0;

    // integer interface, kept for compatibility with std::discrete_distribution
    // (only used in a disabled branch of llama_sampler_dist_apply)
    virtual uint32_t rng_min() = 0;
    virtual uint32_t rng_max() = 0;
    virtual uint32_t next()    = 0; // uniform bits in [rng_min(), rng_max()]

    // uniform double in [0, 1)
    virtual double nextf() = 0;

    // restart the generator from seed s
    virtual void reseed(uint32_t s) = 0;

    // independent copy of this generator
    virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
};
|
|
|
|
|
|
|
|
|
|
// adapter to satisfy UniformRandomBitGenerator for std::discrete_distribution
|
|
|
|
|
// note: not guaranteed to preserve blue noise properties
|
|
|
|
|
// this is only used in a disabled branch of llama_sampler_dist_apply, added for compatibility
|
|
|
|
|
struct llama_dist_urbg {
|
|
|
|
|
using result_type = uint32_t;
|
|
|
|
|
|
|
|
|
|
llama_dist_rng & rng;
|
|
|
|
|
|
|
|
|
|
result_type min() { return rng.rng_min(); }
|
|
|
|
|
result_type max() { return rng.rng_max(); }
|
|
|
|
|
result_type operator()() { return rng.next(); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct llama_dist_rng_mt19937 : llama_dist_rng {
|
|
|
|
|
std::mt19937 rng;
|
|
|
|
|
|
|
|
|
|
llama_dist_rng_mt19937(uint32_t seed) : rng(seed) {}
|
|
|
|
|
|
|
|
|
|
bool requires_sorted() override { return false; }
|
|
|
|
|
|
|
|
|
|
uint32_t rng_min() override { return std::mt19937::min(); }
|
|
|
|
|
uint32_t rng_max() override { return std::mt19937::max(); }
|
|
|
|
|
|
|
|
|
|
uint32_t next() override {
|
|
|
|
|
return rng();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double nextf() override {
|
|
|
|
|
std::uniform_real_distribution<double> dist(0.0, 1.0);
|
|
|
|
|
return dist(rng);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void reseed(uint32_t s) override {
|
|
|
|
|
rng.seed(s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<llama_dist_rng> clone() const override {
|
|
|
|
|
return std::make_unique<llama_dist_rng_mt19937>(*this);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct llama_dist_rng_blue : llama_dist_rng {
|
|
|
|
|
blue_noise_rng bn_rng;
|
|
|
|
|
|
|
|
|
|
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
|
|
|
|
|
|
|
|
|
bool requires_sorted() override { return true; }
|
|
|
|
|
|
|
|
|
|
uint32_t rng_min() override { return 0; }
|
|
|
|
|
uint32_t rng_max() override { return (1u << bn_rng.bit_depth) - 1; }
|
|
|
|
|
|
|
|
|
|
uint32_t next() override {
|
|
|
|
|
return bn_rng.next();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double nextf() override {
|
|
|
|
|
return bn_rng.nextf();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void reseed(uint32_t s) override {
|
|
|
|
|
bn_rng.init(16, s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<llama_dist_rng> clone() const override {
|
|
|
|
|
return std::make_unique<llama_dist_rng_blue>(*this);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static uint32_t get_rng_seed(uint32_t seed) {
|
|
|
|
|
if (seed == LLAMA_DEFAULT_SEED) {
|
|
|
|
|
// use system clock if std::random_device is not a true RNG
|
|
|
|
|
@ -1023,7 +1219,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
|
|
|
|
|
const uint32_t seed;
|
|
|
|
|
uint32_t seed_cur;
|
|
|
|
|
|
|
|
|
|
std::mt19937 rng;
|
|
|
|
|
std::unique_ptr<llama_dist_rng> rng;
|
|
|
|
|
|
|
|
|
|
ggml_tensor * inp_uniform;
|
|
|
|
|
};
|
|
|
|
|
@ -1049,6 +1245,11 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// sort if required by the RNG (e.g., blue noise needs sorted input for proper temporal properties)
|
|
|
|
|
if (ctx->rng->requires_sorted() && !cur_p->sorted) {
|
|
|
|
|
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// max logit for numerical stability
|
|
|
|
|
float max_l = cur_p->data[0].logit;
|
|
|
|
|
if (!cur_p->sorted) {
|
|
|
|
|
@ -1069,8 +1270,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|
|
|
|
// sample from the obtained probabilities and normalize the probs in a single pass
|
|
|
|
|
// this is ~3x faster on Mac with full gpt-oss vocab than the version below
|
|
|
|
|
//
|
|
|
|
|
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
|
|
|
|
const double rnd = dist(ctx->rng);
|
|
|
|
|
const double rnd = ctx->rng->nextf();
|
|
|
|
|
|
|
|
|
|
double sum_run = 0.0f;
|
|
|
|
|
const double sum_tgt = sum_cum*rnd;
|
|
|
|
|
@ -1101,28 +1301,31 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|
|
|
|
cur_p->data[i].p /= sum_cum;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
|
|
|
|
// this implementation is not guaranteed to preserve blue noise properties
|
|
|
|
|
llama_dist_urbg urbg{*ctx->rng};
|
|
|
|
|
cur_p->selected = llama_sample_dist(cur_p, urbg);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
|
|
|
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
|
|
|
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
|
|
|
ctx->rng.seed(ctx->seed_cur);
|
|
|
|
|
ctx->rng->reseed(ctx->seed_cur);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
|
|
|
|
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
|
|
|
|
auto * result = llama_sampler_init_dist(ctx->seed);
|
|
|
|
|
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
|
|
|
|
|
|
|
|
// copy the state
|
|
|
|
|
{
|
|
|
|
|
auto * result_ctx = (llama_sampler_dist *) result->ctx;
|
|
|
|
|
|
|
|
|
|
result_ctx->rng = ctx->rng;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
return llama_sampler_init(
|
|
|
|
|
/* .iface = */ smpl->iface,
|
|
|
|
|
/* .ctx = */ new llama_sampler_dist {
|
|
|
|
|
{ctx->get_name()},
|
|
|
|
|
/* .seed = */ ctx->seed,
|
|
|
|
|
/* .seed_cur = */ ctx->seed_cur,
|
|
|
|
|
/* .rng = */ ctx->rng->clone(),
|
|
|
|
|
/* .inp_uniform = */ nullptr,
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|
|
|
|
@ -1154,6 +1357,30 @@ static void llama_sampler_dist_backend_apply(
|
|
|
|
|
ggml_set_name (sctx->inp_uniform, "uniform");
|
|
|
|
|
ggml_set_input(sctx->inp_uniform);
|
|
|
|
|
|
|
|
|
|
// If the RNG requires sorted input (e.g., blue noise), sort logits first
|
|
|
|
|
// so the CDF walk operates in probability-rank space, not arbitrary vocab order.
|
|
|
|
|
if (sctx->rng->requires_sorted()) {
|
|
|
|
|
auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
|
|
|
|
|
GGML_ASSERT(ggml_nrows(a) == 1);
|
|
|
|
|
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
|
|
|
|
|
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
|
|
|
|
|
return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
|
|
|
|
|
ggml_set_name(sorted_idx, "dist_sorted_idx");
|
|
|
|
|
|
|
|
|
|
data->logits = ggml_sort(data->logits, sorted_idx);
|
|
|
|
|
ggml_set_name(data->logits, "dist_sorted_logits");
|
|
|
|
|
|
|
|
|
|
if (data->candidates) {
|
|
|
|
|
data->candidates = ggml_sort(data->candidates, sorted_idx);
|
|
|
|
|
} else {
|
|
|
|
|
data->candidates = sorted_idx;
|
|
|
|
|
}
|
|
|
|
|
ggml_set_name(data->candidates, "dist_sorted_candidates");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
|
|
|
|
ggml_set_name(probs, "dist_probs");
|
|
|
|
|
|
|
|
|
|
@ -1208,8 +1435,8 @@ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
|
|
|
|
|
// std::uniform_real_distribution<double> and
|
|
|
|
|
// std::uniform_real_distribution<float> with same rng will produce
|
|
|
|
|
// different sequences).
|
|
|
|
|
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
|
|
|
|
const float rnd = dist(sctx->rng);
|
|
|
|
|
// nextf returns double, equivalent to std::uniform_real_distribution<double>
|
|
|
|
|
const float rnd = (float)sctx->rng->nextf();
|
|
|
|
|
|
|
|
|
|
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
@ -1232,10 +1459,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
|
|
|
return llama_sampler_init(
|
|
|
|
|
/* .iface = */ &llama_sampler_dist_i,
|
|
|
|
|
/* .ctx = */ new llama_sampler_dist {
|
|
|
|
|
("dist"),
|
|
|
|
|
{"dist"},
|
|
|
|
|
/* .seed = */ seed,
|
|
|
|
|
/* .seed_cur = */ seed_cur,
|
|
|
|
|
/* .rng = */ std::mt19937(seed_cur),
|
|
|
|
|
/* .rng = */ std::make_unique<llama_dist_rng_mt19937>(seed_cur),
|
|
|
|
|
/* .inp_uniform = */ nullptr,
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
|
|
|
|
|
auto seed_cur = get_rng_seed(seed);
|
|
|
|
|
return llama_sampler_init(
|
|
|
|
|
/* .iface = */ &llama_sampler_dist_i,
|
|
|
|
|
/* .ctx = */ new llama_sampler_dist {
|
|
|
|
|
{"dist-blue-noise"},
|
|
|
|
|
/* .seed = */ seed,
|
|
|
|
|
/* .seed_cur = */ seed_cur,
|
|
|
|
|
/* .rng = */ std::make_unique<llama_dist_rng_blue>(seed_cur),
|
|
|
|
|
/* .inp_uniform = */ nullptr,
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
|