llama : make the sampler rng modular

This commit is contained in:
Jan Boon 2026-02-04 23:05:54 +00:00
parent f271576d81
commit 3b4061981b
1 changed files with 81 additions and 204 deletions

View File

@ -432,6 +432,56 @@ struct blue_noise_rng {
}
};
// abstract RNG interface for the dist sampler
struct llama_dist_rng {
virtual ~llama_dist_rng() = default;
virtual double nextf() = 0; // uniform double in [0, 1)
virtual void reseed(uint32_t s) = 0;
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
};
struct llama_dist_rng_white : llama_dist_rng {
std::mt19937 rng;
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
double nextf() override {
std::uniform_real_distribution<double> dist(0.0, 1.0);
return dist(rng);
}
void reseed(uint32_t s) override {
rng.seed(s);
}
std::unique_ptr<llama_dist_rng> clone() const override {
auto c = std::make_unique<llama_dist_rng_white>(0);
c->rng = rng;
return c;
}
};
struct llama_dist_rng_blue : llama_dist_rng {
blue_noise_rng bn_rng;
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
double nextf() override {
return bn_rng.nextf();
}
void reseed(uint32_t s) override {
bn_rng.init(16, s);
}
std::unique_ptr<llama_dist_rng> clone() const override {
auto c = std::make_unique<llama_dist_rng_blue>(0);
c->bn_rng = bn_rng;
return c;
}
};
static uint32_t get_rng_seed(uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
// use system clock if std::random_device is not a true RNG
@ -1122,7 +1172,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
const uint32_t seed;
uint32_t seed_cur;
std::mt19937 rng;
std::unique_ptr<llama_dist_rng> rng;
ggml_tensor * inp_uniform;
};
@ -1168,8 +1218,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
// sample from the obtained probabilities and normalize the probs in a single pass
// this is ~3x faster on Mac with full gpt-oss vocab than the version below
//
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
const double rnd = dist(ctx->rng);
const double rnd = ctx->rng->nextf();
double sum_run = 0.0f;
const double sum_tgt = sum_cum*rnd;
@ -1200,28 +1249,37 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
cur_p->data[i].p /= sum_cum;
}
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
const double rnd = ctx->rng->nextf();
double cum = 0.0;
for (size_t i = 0; i < cur_p->size; ++i) {
cum += cur_p->data[i].p;
if (cum >= rnd) {
cur_p->selected = i;
break;
}
}
#endif
}
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
ctx->rng->reseed(ctx->seed_cur);
}
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
auto * result = llama_sampler_init_dist(ctx->seed);
auto * ctx = (llama_sampler_dist *) smpl->ctx;
// copy the state
{
auto * result_ctx = (llama_sampler_dist *) result->ctx;
result_ctx->rng = ctx->rng;
}
return result;
return llama_sampler_init(
/* .iface = */ smpl->iface,
/* .ctx = */ new llama_sampler_dist {
{ctx->get_name()},
/* .seed = */ ctx->seed,
/* .seed_cur = */ ctx->seed_cur,
/* .rng = */ ctx->rng->clone(),
/* .inp_uniform = */ nullptr,
}
);
}
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@ -1307,8 +1365,8 @@ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
// std::uniform_real_distribution<double> and
// std::uniform_real_distribution<float> with same rng will produce
// different sequences).
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
// nextf returns double, equivalent to std::uniform_real_distribution<double>
const float rnd = (float)sctx->rng->nextf();
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
@ -1331,201 +1389,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
("dist"),
{"dist"},
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .rng = */ std::make_unique<llama_dist_rng_white>(seed_cur),
/* .inp_uniform = */ nullptr,
}
);
}
// dist (blue noise)
struct llama_sampler_dist_blue_noise : public llama_sampler_backend {
const uint32_t seed;
uint32_t seed_cur;
blue_noise_rng bn_rng;
ggml_tensor * inp_uniform;
};
static const char * llama_sampler_dist_blue_noise_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_dist_blue_noise_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
// edge cases
if (cur_p->size == 0) {
cur_p->selected = -1;
return;
}
cur_p->selected = 0;
if (cur_p->size == 1) {
cur_p->data[0].p = 1.0f;
return;
}
// max logit for numerical stability
float max_l = cur_p->data[0].logit;
if (!cur_p->sorted) {
for (size_t i = 1; i < cur_p->size; ++i) {
max_l = std::max(max_l, cur_p->data[i].logit);
}
}
// apply softmax to obtain the probabilities
double sum_cum = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) {
float p = expf(cur_p->data[i].logit - max_l);
cur_p->data[i].p = p;
sum_cum += p;
}
// sample using blue noise RNG
const double rnd = ctx->bn_rng.nextf();
double sum_run = 0.0f;
const double sum_tgt = sum_cum*rnd;
bool found = false;
for (size_t i = 0; i < cur_p->size; ++i) {
if (!found) {
sum_run += cur_p->data[i].p;
if (sum_run >= sum_tgt) {
cur_p->selected = i;
found = true;
}
}
// normalize probs
cur_p->data[i].p /= sum_cum;
}
assert(found);
if (!found) {
cur_p->selected = cur_p->size - 1;
}
}
static void llama_sampler_dist_blue_noise_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->bn_rng.init(16, ctx->seed_cur);
}
static struct llama_sampler * llama_sampler_dist_blue_noise_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_dist_blue_noise *) smpl->ctx;
auto * result = llama_sampler_init_dist_blue_noise(ctx->seed);
// copy the state
{
auto * result_ctx = (llama_sampler_dist_blue_noise *) result->ctx;
result_ctx->seed_cur = ctx->seed_cur;
result_ctx->bn_rng = ctx->bn_rng;
}
return result;
}
static void llama_sampler_dist_blue_noise_free(struct llama_sampler * smpl) {
delete (llama_sampler_dist_blue_noise *) smpl->ctx;
}
static bool llama_sampler_dist_blue_noise_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
const bool res = llama_sampler_backend_support(smpl, buft);
sctx->init(res);
return res;
}
static void llama_sampler_dist_blue_noise_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
GGML_UNUSED(gf);
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
ggml_set_name (sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
ggml_set_name(probs, "dist_probs");
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
ggml_set_name(cumsum, "dist_cumsum");
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
ggml_set_name(diff, "dist_cumsum");
struct ggml_tensor * mask = ggml_step(ctx, diff);
ggml_set_name(mask, "dist_mask");
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
ggml_set_name(idxf, "dist_index_f32");
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
ggml_set_name(idx, "dist_index_i32");
struct ggml_tensor * sampled_token = idx;
if (data->candidates != nullptr) {
struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
sampled_token = ggml_get_rows(ctx, candidates, idx);
ggml_set_name(sampled_token, "dist_sampled_token");
}
data->sampled = sampled_token;
data->probs = probs;
}
static void llama_sampler_dist_blue_noise_backend_set_input(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
GGML_ASSERT(sctx->inp_uniform != nullptr);
const float rnd = (float)sctx->bn_rng.nextf();
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
static struct llama_sampler_i llama_sampler_dist_blue_noise_i = {
/* .name = */ llama_sampler_dist_blue_noise_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_dist_blue_noise_apply,
/* .reset = */ llama_sampler_dist_blue_noise_reset,
/* .clone = */ llama_sampler_dist_blue_noise_clone,
/* .free = */ llama_sampler_dist_blue_noise_free,
/* .backend_init = */ llama_sampler_dist_blue_noise_backend_init,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ llama_sampler_dist_blue_noise_backend_apply,
/* .backend_set_input = */ llama_sampler_dist_blue_noise_backend_set_input,
};
struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_blue_noise_i,
/* .ctx = */ new llama_sampler_dist_blue_noise {
("dist-blue-noise"),
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
{"dist-blue-noise"},
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .bn_rng = */ blue_noise_rng(16, seed_cur),
/* .rng = */ std::make_unique<llama_dist_rng_blue>(seed_cur),
/* .inp_uniform = */ nullptr,
}
);
@ -4119,10 +4000,6 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
}
if (smpl->iface == &llama_sampler_dist_blue_noise_i) {
return ((const llama_sampler_dist_blue_noise *) smpl->ctx)->seed_cur;
}
if (smpl->iface == &llama_sampler_mirostat_i) {
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
}