llama : make the sampler rng modular
This commit is contained in:
parent
f271576d81
commit
3b4061981b
|
|
@ -432,6 +432,56 @@ struct blue_noise_rng {
|
|||
}
|
||||
};
|
||||
|
||||
// abstract RNG interface for the dist sampler
|
||||
struct llama_dist_rng {
|
||||
virtual ~llama_dist_rng() = default;
|
||||
|
||||
virtual double nextf() = 0; // uniform double in [0, 1)
|
||||
virtual void reseed(uint32_t s) = 0;
|
||||
virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
|
||||
};
|
||||
|
||||
struct llama_dist_rng_white : llama_dist_rng {
|
||||
std::mt19937 rng;
|
||||
|
||||
llama_dist_rng_white(uint32_t seed) : rng(seed) {}
|
||||
|
||||
double nextf() override {
|
||||
std::uniform_real_distribution<double> dist(0.0, 1.0);
|
||||
return dist(rng);
|
||||
}
|
||||
|
||||
void reseed(uint32_t s) override {
|
||||
rng.seed(s);
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
auto c = std::make_unique<llama_dist_rng_white>(0);
|
||||
c->rng = rng;
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_dist_rng_blue : llama_dist_rng {
|
||||
blue_noise_rng bn_rng;
|
||||
|
||||
llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
|
||||
|
||||
double nextf() override {
|
||||
return bn_rng.nextf();
|
||||
}
|
||||
|
||||
void reseed(uint32_t s) override {
|
||||
bn_rng.init(16, s);
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_dist_rng> clone() const override {
|
||||
auto c = std::make_unique<llama_dist_rng_blue>(0);
|
||||
c->bn_rng = bn_rng;
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
static uint32_t get_rng_seed(uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
// use system clock if std::random_device is not a true RNG
|
||||
|
|
@ -1122,7 +1172,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
|
|||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
std::mt19937 rng;
|
||||
std::unique_ptr<llama_dist_rng> rng;
|
||||
|
||||
ggml_tensor * inp_uniform;
|
||||
};
|
||||
|
|
@ -1168,8 +1218,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|||
// sample from the obtained probabilities and normalize the probs in a single pass
|
||||
// this is ~3x faster on Mac with full gpt-oss vocab than the version below
|
||||
//
|
||||
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
||||
const double rnd = dist(ctx->rng);
|
||||
const double rnd = ctx->rng->nextf();
|
||||
|
||||
double sum_run = 0.0f;
|
||||
const double sum_tgt = sum_cum*rnd;
|
||||
|
|
@ -1200,28 +1249,37 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|||
cur_p->data[i].p /= sum_cum;
|
||||
}
|
||||
|
||||
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
||||
const double rnd = ctx->rng->nextf();
|
||||
double cum = 0.0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cum += cur_p->data[i].p;
|
||||
if (cum >= rnd) {
|
||||
cur_p->selected = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
ctx->rng->reseed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
||||
auto * result = llama_sampler_init_dist(ctx->seed);
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_dist *) result->ctx;
|
||||
|
||||
result_ctx->rng = ctx->rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ smpl->iface,
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
{ctx->get_name()},
|
||||
/* .seed = */ ctx->seed,
|
||||
/* .seed_cur = */ ctx->seed_cur,
|
||||
/* .rng = */ ctx->rng->clone(),
|
||||
/* .inp_uniform = */ nullptr,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
||||
|
|
@ -1307,8 +1365,8 @@ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
|
|||
// std::uniform_real_distribution<double> and
|
||||
// std::uniform_real_distribution<float> with same rng will produce
|
||||
// different sequences).
|
||||
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
||||
const float rnd = dist(sctx->rng);
|
||||
// nextf returns double, equivalent to std::uniform_real_distribution<double>
|
||||
const float rnd = (float)sctx->rng->nextf();
|
||||
|
||||
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
||||
}
|
||||
|
|
@ -1331,201 +1389,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|||
return llama_sampler_init(
|
||||
/* .iface = */ &llama_sampler_dist_i,
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
("dist"),
|
||||
{"dist"},
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
/* .rng = */ std::make_unique<llama_dist_rng_white>(seed_cur),
|
||||
/* .inp_uniform = */ nullptr,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// dist (blue noise)
|
||||
|
||||
struct llama_sampler_dist_blue_noise : public llama_sampler_backend {
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
blue_noise_rng bn_rng;
|
||||
|
||||
ggml_tensor * inp_uniform;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_dist_blue_noise_name(const struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
return sctx->get_name();
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_blue_noise_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
|
||||
// edge cases
|
||||
if (cur_p->size == 0) {
|
||||
cur_p->selected = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
cur_p->selected = 0;
|
||||
|
||||
if (cur_p->size == 1) {
|
||||
cur_p->data[0].p = 1.0f;
|
||||
return;
|
||||
}
|
||||
|
||||
// max logit for numerical stability
|
||||
float max_l = cur_p->data[0].logit;
|
||||
if (!cur_p->sorted) {
|
||||
for (size_t i = 1; i < cur_p->size; ++i) {
|
||||
max_l = std::max(max_l, cur_p->data[i].logit);
|
||||
}
|
||||
}
|
||||
|
||||
// apply softmax to obtain the probabilities
|
||||
double sum_cum = 0.0f;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float p = expf(cur_p->data[i].logit - max_l);
|
||||
cur_p->data[i].p = p;
|
||||
sum_cum += p;
|
||||
}
|
||||
|
||||
// sample using blue noise RNG
|
||||
const double rnd = ctx->bn_rng.nextf();
|
||||
|
||||
double sum_run = 0.0f;
|
||||
const double sum_tgt = sum_cum*rnd;
|
||||
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (!found) {
|
||||
sum_run += cur_p->data[i].p;
|
||||
if (sum_run >= sum_tgt) {
|
||||
cur_p->selected = i;
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
cur_p->data[i].p /= sum_cum;
|
||||
}
|
||||
|
||||
assert(found);
|
||||
if (!found) {
|
||||
cur_p->selected = cur_p->size - 1;
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_blue_noise_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->bn_rng.init(16, ctx->seed_cur);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_dist_blue_noise_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
auto * result = llama_sampler_init_dist_blue_noise(ctx->seed);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_dist_blue_noise *) result->ctx;
|
||||
|
||||
result_ctx->seed_cur = ctx->seed_cur;
|
||||
result_ctx->bn_rng = ctx->bn_rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_blue_noise_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
}
|
||||
|
||||
static bool llama_sampler_dist_blue_noise_backend_init(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
|
||||
const bool res = llama_sampler_backend_support(smpl, buft);
|
||||
|
||||
sctx->init(res);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_blue_noise_backend_apply(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_data * data) {
|
||||
GGML_UNUSED(gf);
|
||||
|
||||
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
|
||||
sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
||||
ggml_set_name (sctx->inp_uniform, "uniform");
|
||||
ggml_set_input(sctx->inp_uniform);
|
||||
|
||||
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
||||
ggml_set_name(probs, "dist_probs");
|
||||
|
||||
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
|
||||
ggml_set_name(cumsum, "dist_cumsum");
|
||||
|
||||
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
|
||||
ggml_set_name(diff, "dist_cumsum");
|
||||
|
||||
struct ggml_tensor * mask = ggml_step(ctx, diff);
|
||||
ggml_set_name(mask, "dist_mask");
|
||||
|
||||
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
|
||||
ggml_set_name(idxf, "dist_index_f32");
|
||||
|
||||
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
|
||||
ggml_set_name(idx, "dist_index_i32");
|
||||
|
||||
struct ggml_tensor * sampled_token = idx;
|
||||
if (data->candidates != nullptr) {
|
||||
struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
|
||||
|
||||
sampled_token = ggml_get_rows(ctx, candidates, idx);
|
||||
ggml_set_name(sampled_token, "dist_sampled_token");
|
||||
}
|
||||
|
||||
data->sampled = sampled_token;
|
||||
data->probs = probs;
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_blue_noise_backend_set_input(struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
|
||||
|
||||
GGML_ASSERT(sctx->inp_uniform != nullptr);
|
||||
|
||||
const float rnd = (float)sctx->bn_rng.nextf();
|
||||
|
||||
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_dist_blue_noise_i = {
|
||||
/* .name = */ llama_sampler_dist_blue_noise_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_dist_blue_noise_apply,
|
||||
/* .reset = */ llama_sampler_dist_blue_noise_reset,
|
||||
/* .clone = */ llama_sampler_dist_blue_noise_clone,
|
||||
/* .free = */ llama_sampler_dist_blue_noise_free,
|
||||
/* .backend_init = */ llama_sampler_dist_blue_noise_backend_init,
|
||||
/* .backend_accept = */ nullptr,
|
||||
/* .backend_apply = */ llama_sampler_dist_blue_noise_backend_apply,
|
||||
/* .backend_set_input = */ llama_sampler_dist_blue_noise_backend_set_input,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &llama_sampler_dist_blue_noise_i,
|
||||
/* .ctx = */ new llama_sampler_dist_blue_noise {
|
||||
("dist-blue-noise"),
|
||||
/* .iface = */ &llama_sampler_dist_i,
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
{"dist-blue-noise"},
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .bn_rng = */ blue_noise_rng(16, seed_cur),
|
||||
/* .rng = */ std::make_unique<llama_dist_rng_blue>(seed_cur),
|
||||
/* .inp_uniform = */ nullptr,
|
||||
}
|
||||
);
|
||||
|
|
@ -4119,10 +4000,6 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|||
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
|
||||
}
|
||||
|
||||
if (smpl->iface == &llama_sampler_dist_blue_noise_i) {
|
||||
return ((const llama_sampler_dist_blue_noise *) smpl->ctx)->seed_cur;
|
||||
}
|
||||
|
||||
if (smpl->iface == &llama_sampler_mirostat_i) {
|
||||
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue