diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 09fd3a4700..c41666aaa7 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -432,6 +432,60 @@ struct blue_noise_rng {
     }
 };
 
+// abstract RNG interface for the dist sampler
+struct llama_dist_rng {
+    virtual ~llama_dist_rng() = default;
+
+    virtual double nextf() = 0; // uniform double in [0, 1)
+    virtual void reseed(uint32_t s) = 0;
+    virtual std::unique_ptr<llama_dist_rng> clone() const = 0;
+};
+
+// RNG for the dist sampler that draws from a std::mt19937 engine
+struct llama_dist_rng_white : llama_dist_rng {
+    std::mt19937 rng;
+
+    explicit llama_dist_rng_white(uint32_t seed) : rng(seed) {}
+
+    double nextf() override {
+        std::uniform_real_distribution<double> dist(0.0, 1.0);
+        return dist(rng);
+    }
+
+    void reseed(uint32_t s) override {
+        rng.seed(s);
+    }
+
+    std::unique_ptr<llama_dist_rng> clone() const override {
+        // the seed passed here is a placeholder: the engine state is copied over it below
+        auto c = std::make_unique<llama_dist_rng_white>(0);
+        c->rng = rng;
+        return c;
+    }
+};
+
+// RNG for the dist sampler that forwards to blue_noise_rng
+struct llama_dist_rng_blue : llama_dist_rng {
+    blue_noise_rng bn_rng;
+
+    explicit llama_dist_rng_blue(uint32_t seed) : bn_rng(16, seed) {}
+
+    double nextf() override {
+        return bn_rng.nextf();
+    }
+
+    void reseed(uint32_t s) override {
+        bn_rng.init(16, s);
+    }
+
+    std::unique_ptr<llama_dist_rng> clone() const override {
+        // the seed passed here is a placeholder: the generator state is copied over it below
+        auto c = std::make_unique<llama_dist_rng_blue>(0);
+        c->bn_rng = bn_rng;
+        return c;
+    }
+};
+
 static uint32_t get_rng_seed(uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         // use system clock if std::random_device is not a true RNG
@@ -1122,7 +1176,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
     const uint32_t seed;
           uint32_t seed_cur;
 
-    std::mt19937 rng;
+    std::unique_ptr<llama_dist_rng> rng;
 
     ggml_tensor * inp_uniform;
 };
@@ -1168,8 +1222,7 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
     // sample from the obtained probabilities and normalize the probs in a single pass
     // this is ~3x faster on Mac with full gpt-oss vocab than the version below
     //
-    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
-    const double rnd = dist(ctx->rng);
+    const double rnd = ctx->rng->nextf();
 
           double sum_run = 0.0f;
     const double sum_tgt = sum_cum*rnd;
@@ -1200,28 +1253,37 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
         cur_p->data[i].p /= sum_cum;
     }
 
-    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+    const double rnd = ctx->rng->nextf();
+    double cum = 0.0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cum += cur_p->data[i].p;
+        if (cum >= rnd) {
+            cur_p->selected = i;
+            break;
+        }
+    }
 #endif
 }
 
 static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
     ctx->seed_cur = get_rng_seed(ctx->seed);
-    ctx->rng.seed(ctx->seed_cur);
+    ctx->rng->reseed(ctx->seed_cur);
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
-    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
-    auto * result = llama_sampler_init_dist(ctx->seed);
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    // copy the state
-    {
-        auto * result_ctx = (llama_sampler_dist *) result->ctx;
-
-        result_ctx->rng = ctx->rng;
-    }
-
-    return result;
+    return llama_sampler_init(
+        /* .iface = */ smpl->iface,
+        /* .ctx   = */ new llama_sampler_dist {
+            {ctx->get_name()},
+            /* .seed        = */ ctx->seed,
+            /* .seed_cur    = */ ctx->seed_cur,
+            /* .rng         = */ ctx->rng->clone(),
+            /* .inp_uniform = */ nullptr,
+        }
+    );
 }
 
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@@ -1307,8 +1369,8 @@ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
     // std::uniform_real_distribution<double> and
     // std::uniform_real_distribution<float> with same rng will produce
     // different sequences).
-    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
-    const float rnd = dist(sctx->rng);
+    // nextf returns double, equivalent to std::uniform_real_distribution<double>
+    const float rnd = (float)sctx->rng->nextf();
 
     ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
 }
@@ -1331,201 +1393,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
-            ("dist"),
+            {"dist"},
             /* .seed        = */ seed,
             /* .seed_cur    = */ seed_cur,
-            /* .rng         = */ std::mt19937(seed_cur),
+            /* .rng         = */ std::make_unique<llama_dist_rng_white>(seed_cur),
             /* .inp_uniform = */ nullptr,
         }
     );
 }
 
-// dist (blue noise)
-
-struct llama_sampler_dist_blue_noise : public llama_sampler_backend {
-    const uint32_t seed;
-          uint32_t seed_cur;
-
-    blue_noise_rng bn_rng;
-
-    ggml_tensor * inp_uniform;
-};
-
-static const char * llama_sampler_dist_blue_noise_name(const struct llama_sampler * smpl) {
-    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
-    return sctx->get_name();
-}
-
-static void llama_sampler_dist_blue_noise_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
-
-    // edge cases
-    if (cur_p->size == 0) {
-        cur_p->selected = -1;
-        return;
-    }
-
-    cur_p->selected = 0;
-
-    if (cur_p->size == 1) {
-        cur_p->data[0].p = 1.0f;
-        return;
-    }
-
-    // max logit for numerical stability
-    float max_l = cur_p->data[0].logit;
-    if (!cur_p->sorted) {
-        for (size_t i = 1; i < cur_p->size; ++i) {
-            max_l = std::max(max_l, cur_p->data[i].logit);
-        }
-    }
-
-    // apply softmax to obtain the probabilities
-    double sum_cum = 0.0f;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        float p = expf(cur_p->data[i].logit - max_l);
-        cur_p->data[i].p = p;
-        sum_cum += p;
-    }
-
-    // sample using blue noise RNG
-    const double rnd = ctx->bn_rng.nextf();
-
-          double sum_run = 0.0f;
-    const double sum_tgt = sum_cum*rnd;
-
-    bool found = false;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        if (!found) {
-            sum_run += cur_p->data[i].p;
-            if (sum_run >= sum_tgt) {
-                cur_p->selected = i;
-                found = true;
-            }
-        }
-
-        // normalize probs
-        cur_p->data[i].p /= sum_cum;
-    }
-
-    assert(found);
-    if (!found) {
-        cur_p->selected = cur_p->size - 1;
-    }
-}
-
-static void llama_sampler_dist_blue_noise_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
-    ctx->seed_cur = get_rng_seed(ctx->seed);
-    ctx->bn_rng.init(16, ctx->seed_cur);
-}
-
-static struct llama_sampler * llama_sampler_dist_blue_noise_clone(const struct llama_sampler * smpl) {
-    const auto * ctx = (const llama_sampler_dist_blue_noise *) smpl->ctx;
-    auto * result = llama_sampler_init_dist_blue_noise(ctx->seed);
-
-    // copy the state
-    {
-        auto * result_ctx = (llama_sampler_dist_blue_noise *) result->ctx;
-
-        result_ctx->seed_cur = ctx->seed_cur;
-        result_ctx->bn_rng = ctx->bn_rng;
-    }
-
-    return result;
-}
-
-static void llama_sampler_dist_blue_noise_free(struct llama_sampler * smpl) {
-    delete (llama_sampler_dist_blue_noise *) smpl->ctx;
-}
-
-static bool llama_sampler_dist_blue_noise_backend_init(
-        struct llama_sampler * smpl,
-        ggml_backend_buffer_type_t buft) {
-    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
-
-    const bool res = llama_sampler_backend_support(smpl, buft);
-
-    sctx->init(res);
-
-    return res;
-}
-
-static void llama_sampler_dist_blue_noise_backend_apply(
-        struct llama_sampler * smpl,
-        struct ggml_context * ctx,
-        struct ggml_cgraph * gf,
-        struct llama_sampler_data * data) {
-    GGML_UNUSED(gf);
-
-    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
-
-    sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
-    ggml_set_name (sctx->inp_uniform, "uniform");
-    ggml_set_input(sctx->inp_uniform);
-
-    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
-    ggml_set_name(probs, "dist_probs");
-
-    struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
-    ggml_set_name(cumsum, "dist_cumsum");
-
-    struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
-    ggml_set_name(diff, "dist_cumsum");
-
-    struct ggml_tensor * mask = ggml_step(ctx, diff);
-    ggml_set_name(mask, "dist_mask");
-
-    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
-    ggml_set_name(idxf, "dist_index_f32");
-
-    struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
-    ggml_set_name(idx, "dist_index_i32");
-
-    struct ggml_tensor * sampled_token = idx;
-    if (data->candidates != nullptr) {
-        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
-
-        sampled_token = ggml_get_rows(ctx, candidates, idx);
-        ggml_set_name(sampled_token, "dist_sampled_token");
-    }
-
-    data->sampled = sampled_token;
-    data->probs = probs;
-}
-
-static void llama_sampler_dist_blue_noise_backend_set_input(struct llama_sampler * smpl) {
-    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
-
-    GGML_ASSERT(sctx->inp_uniform != nullptr);
-
-    const float rnd = (float)sctx->bn_rng.nextf();
-
-    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
-}
-
-static struct llama_sampler_i llama_sampler_dist_blue_noise_i = {
-    /* .name              = */ llama_sampler_dist_blue_noise_name,
-    /* .accept            = */ nullptr,
-    /* .apply             = */ llama_sampler_dist_blue_noise_apply,
-    /* .reset             = */ llama_sampler_dist_blue_noise_reset,
-    /* .clone             = */ llama_sampler_dist_blue_noise_clone,
-    /* .free              = */ llama_sampler_dist_blue_noise_free,
-    /* .backend_init      = */ llama_sampler_dist_blue_noise_backend_init,
-    /* .backend_accept    = */ nullptr,
-    /* .backend_apply     = */ llama_sampler_dist_blue_noise_backend_apply,
-    /* .backend_set_input = */ llama_sampler_dist_blue_noise_backend_set_input,
-};
-
 struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
     return llama_sampler_init(
-        /* .iface = */ &llama_sampler_dist_blue_noise_i,
-        /* .ctx   = */ new llama_sampler_dist_blue_noise {
-            ("dist-blue-noise"),
+        /* .iface = */ &llama_sampler_dist_i,
+        /* .ctx   = */ new llama_sampler_dist {
+            {"dist-blue-noise"},
             /* .seed        = */ seed,
             /* .seed_cur    = */ seed_cur,
-            /* .bn_rng      = */ blue_noise_rng(16, seed_cur),
+            /* .rng         = */ std::make_unique<llama_dist_rng_blue>(seed_cur),
             /* .inp_uniform = */ nullptr,
         }
     );
@@ -4119,10 +4004,6 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
         return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
     }
 
-    if (smpl->iface == &llama_sampler_dist_blue_noise_i) {
-        return ((const llama_sampler_dist_blue_noise *) smpl->ctx)->seed_cur;
-    }
-
     if (smpl->iface == &llama_sampler_mirostat_i) {
         return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
     }