From f271576d81ca920d5d35a76f44a663da47608adb Mon Sep 17 00:00:00 2001
From: Jan Boon
Date: Wed, 4 Feb 2026 19:12:50 +0000
Subject: [PATCH] llama : initial blue noise test implementation

---
 common/arg.cpp               |   7 ++
 common/common.h              |   1 +
 common/sampling.cpp          |   6 +-
 include/llama.h              |   3 +-
 src/llama-sampling.cpp       | 197 ++++++++++++++++++++++++++++++++++-
 tools/server/server-task.cpp |   3 +
 6 files changed, 214 insertions(+), 3 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 5fbc9022c0..924b5198a2 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1577,6 +1577,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.ignore_eos = true;
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--blue-noise"},
+        "use blue noise RNG for sampling instead of white noise",
+        [](common_params & params) {
+            params.sampling.blue_noise = true;
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--temp"}, "N",
         string_format("temperature (default: %.2f)", (double)params.sampling.temp),
diff --git a/common/common.h b/common/common.h
index 398ebb0960..0a76a1e26c 100644
--- a/common/common.h
+++ b/common/common.h
@@ -209,6 +209,7 @@ struct common_params_sampling {
     bool ignore_eos       = false;
     bool no_perf          = false; // disable performance metrics
     bool timing_per_token = false;
+    bool blue_noise       = false; // use blue noise RNG instead of white noise for dist sampler
 
     uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
 
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 11a1d48398..2811eb3a48 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -313,7 +313,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
             samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
         } else {
             // default: sample from distribution
-            samplers.push_back(llama_sampler_init_dist(params.seed));
+            if (params.blue_noise) {
+                samplers.push_back(llama_sampler_init_dist_blue_noise(params.seed));
+            } else {
+                samplers.push_back(llama_sampler_init_dist(params.seed));
+            }
         }
     } else if (params.mirostat == 1) {
         samplers.push_back(llama_sampler_init_temp(params.temp));
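For reference, the blue-noise sampler slots into the same chain position as llama_sampler_init_dist, so it can also be driven directly through the public sampler API. A minimal sketch, assuming a llama_context created elsewhere; the helper name, the top-k value and the temperature below are illustrative and not defaults introduced by this patch:

    #include "llama.h"

    // Build a sampling chain that ends in the blue-noise dist sampler instead of
    // the default white-noise one, then sample from the last decoded token.
    static llama_token sample_with_blue_noise(struct llama_context * ctx) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));    // illustrative truncation
        llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));   // illustrative temperature
        llama_sampler_chain_add(chain, llama_sampler_init_dist_blue_noise(LLAMA_DEFAULT_SEED));

        const llama_token tok = llama_sampler_sample(chain, ctx, -1);

        llama_sampler_free(chain);
        return tok;
    }

Apart from llama_sampler_init_dist_blue_noise itself, everything above is existing llama.h API; only the final element of the chain changes relative to the white-noise default.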
diff --git a/include/llama.h b/include/llama.h
index bf4e28a8be..22f08e1683 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1295,7 +1295,8 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
 
     /// seed == LLAMA_DEFAULT_SEED to use a random seed.
-    LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist           (uint32_t seed);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     /// Setting k <= 0 makes this a noop
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 7c83095582..09fd3a4700 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -369,7 +369,7 @@ struct blue_noise_rng {
         const int n = (int)states.size();
         position = 0;
 
-        // 5 reachable states with stationary distribution 3:3:2:1:1 (out of 10)
+        // 5 reachable states with distribution 3:3:2:1:1
         static const int8_t tbl[10][2] = {
             { 0, 0}, { 0, 0}, { 0, 0},
             {-1, 0}, {-1, 0}, {-1, 0},
@@ -1340,6 +1340,197 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     );
 }
 
+// dist (blue noise)
+
+struct llama_sampler_dist_blue_noise : public llama_sampler_backend {
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    blue_noise_rng bn_rng;
+
+    ggml_tensor * inp_uniform;
+};
+
+static const char * llama_sampler_dist_blue_noise_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
+    return sctx->get_name();
+}
+
+static void llama_sampler_dist_blue_noise_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
+
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+    // sample using blue noise RNG
+    const double rnd = ctx->bn_rng.nextf();
+
+    double sum_run = 0.0f;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+}
+
+static void llama_sampler_dist_blue_noise_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->bn_rng.init(16, ctx->seed_cur);
+}
+
+static struct llama_sampler * llama_sampler_dist_blue_noise_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_dist_blue_noise *) smpl->ctx;
+    auto * result = llama_sampler_init_dist_blue_noise(ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_dist_blue_noise *) result->ctx;
+
+        result_ctx->seed_cur = ctx->seed_cur;
+        result_ctx->bn_rng   = ctx->bn_rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_dist_blue_noise_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_dist_blue_noise *) smpl->ctx;
+}
+
+static bool llama_sampler_dist_blue_noise_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_dist_blue_noise_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+
+    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
+
+    sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    ggml_set_name (sctx->inp_uniform, "uniform");
+    ggml_set_input(sctx->inp_uniform);
+
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "dist_probs");
+
+    struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
+    ggml_set_name(cumsum, "dist_cumsum");
+
+    struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
+    ggml_set_name(diff, "dist_diff");
+
+    struct ggml_tensor * mask = ggml_step(ctx, diff);
+    ggml_set_name(mask, "dist_mask");
+
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "dist_index_f32");
+
+    struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
+    ggml_set_name(idx, "dist_index_i32");
+
+    struct ggml_tensor * sampled_token = idx;
+    if (data->candidates != nullptr) {
+        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
+
+        sampled_token = ggml_get_rows(ctx, candidates, idx);
+        ggml_set_name(sampled_token, "dist_sampled_token");
+    }
+
+    data->sampled = sampled_token;
+    data->probs   = probs;
+}
+
+static void llama_sampler_dist_blue_noise_backend_set_input(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist_blue_noise *) smpl->ctx;
+
+    GGML_ASSERT(sctx->inp_uniform != nullptr);
+
+    const float rnd = (float) sctx->bn_rng.nextf();
+
+    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
+}
+
+static struct llama_sampler_i llama_sampler_dist_blue_noise_i = {
+    /* .name              = */ llama_sampler_dist_blue_noise_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_dist_blue_noise_apply,
+    /* .reset             = */ llama_sampler_dist_blue_noise_reset,
+    /* .clone             = */ llama_sampler_dist_blue_noise_clone,
+    /* .free              = */ llama_sampler_dist_blue_noise_free,
+    /* .backend_init      = */ llama_sampler_dist_blue_noise_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_dist_blue_noise_backend_apply,
+    /* .backend_set_input = */ llama_sampler_dist_blue_noise_backend_set_input,
+};
+
+struct llama_sampler * llama_sampler_init_dist_blue_noise(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_dist_blue_noise_i,
+        /* .ctx   = */ new llama_sampler_dist_blue_noise {
+            ("dist-blue-noise"),
+            /* .seed        = */ seed,
+            /* .seed_cur    = */ seed_cur,
+            /* .bn_rng      = */ blue_noise_rng(16, seed_cur),
+            /* .inp_uniform = */ nullptr,
+        }
+    );
+}
+
 // top-k
 
 struct llama_sampler_top_k : public llama_sampler_backend {
@@ -3928,6 +4119,10 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
         return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
     }
 
+    if (smpl->iface == &llama_sampler_dist_blue_noise_i) {
+        return ((const llama_sampler_dist_blue_noise *) smpl->ctx)->seed_cur;
+    }
+
     if (smpl->iface == &llama_sampler_mirostat_i) {
         return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
     }
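The backend path here is plain inverse-CDF sampling expressed as tensor ops: softmax gives dist_probs, ggml_cumsum gives the CDF, step(cumsum - uniform) produces a 0/1 mask whose sum counts how many cumulative values have already passed the uniform draw, and ggml_scale_bias with scale -1.0f and bias mask->ne[0] turns that count back into the selected index before the cast to I32. A scalar sketch of the same arithmetic, independent of ggml; the function and variable names are illustrative, and u stands for the value that backend_set_input writes into the "uniform" tensor from bn_rng.nextf():

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Scalar equivalent of the dist_* graph: softmax -> cumsum -> step -> n - count.
    // Assumes a non-empty logits vector and 0 <= u < 1.
    static size_t sample_index_cdf(const std::vector<float> & logits, float u) {
        const size_t n = logits.size();

        // softmax (dist_probs)
        float max_l = logits[0];
        for (float l : logits) {
            max_l = std::max(max_l, l);
        }
        std::vector<float> probs(n);
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            probs[i] = std::exp(logits[i] - max_l);
            sum += probs[i];
        }
        for (float & p : probs) {
            p /= sum;
        }

        // cumsum (dist_cumsum) and mask (dist_mask): count how many cumulative
        // values have already passed the uniform draw u
        float cum = 0.0f;
        size_t ones = 0;
        for (size_t i = 0; i < n; ++i) {
            cum += probs[i];
            if (cum - u > 0.0f) {
                ones++;
            }
        }

        // dist_index: n - count, clamped for the u ~ 1 rounding edge case
        return ones > 0 ? n - ones : n - 1;
    }

The scale/bias trick avoids an argmax or a host-side scan over the candidates, which keeps the whole sampling step inside the graph.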
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 2d25db63b7..16c3cf12d0 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -66,6 +66,7 @@ json task_params::to_json(bool only_metrics) const {
         {"n_keep",             n_keep},
         {"n_discard",          n_discard},
         {"ignore_eos",         sampling.ignore_eos},
+        {"blue_noise",         sampling.blue_noise},
         {"stream",             stream},
         {"n_probs",            sampling.n_probs},
         {"min_keep",           sampling.min_keep},
@@ -125,6 +126,7 @@ json task_params::to_json(bool only_metrics) const {
         {"n_keep",             n_keep},
         {"n_discard",          n_discard},
         {"ignore_eos",         sampling.ignore_eos},
+        {"blue_noise",         sampling.blue_noise},
         {"stream",             stream},
         {"logit_bias",         format_logit_bias(sampling.logit_bias)},
         {"n_probs",            sampling.n_probs},
@@ -467,6 +469,7 @@ task_params server_task::params_from_json_cmpl(
         }
     }
 
+    params.sampling.blue_noise = json_value(data, "blue_noise", params_base.sampling.blue_noise);
    params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
    if (params.sampling.ignore_eos) {
        params.sampling.logit_bias.insert(
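The CPU path can be exercised without a model by applying the sampler to a hand-built llama_token_data_array. A rough test sketch along those lines; the logits, the seed, and the reset-determinism check are illustrative and not part of this patch:

    #include "llama.h"

    #include <cassert>
    #include <cstdio>

    int main() {
        // four fake candidates with fixed logits
        llama_token_data cand[4] = {
            { 0, 1.0f, 0.0f },
            { 1, 2.0f, 0.0f },
            { 2, 0.5f, 0.0f },
            { 3, 3.0f, 0.0f },
        };

        llama_token_data_array cur_p = { cand, 4, /* selected */ -1, /* sorted */ false };

        struct llama_sampler * smpl = llama_sampler_init_dist_blue_noise(1234);

        llama_sampler_apply(smpl, &cur_p);
        const llama_token first = cur_p.data[cur_p.selected].id;

        // an explicit seed is re-derived in reset(), so the same draw is expected again
        llama_sampler_reset(smpl);
        cur_p.selected = -1;
        llama_sampler_apply(smpl, &cur_p);
        assert(cur_p.data[cur_p.selected].id == first);

        printf("sampled token: %d\n", first);

        llama_sampler_free(smpl);
        return 0;
    }

Since reset() re-seeds bn_rng from get_rng_seed(seed), an explicit seed should make the first draw after a reset reproducible; with LLAMA_DEFAULT_SEED the seed is regenerated and the check would not hold in general.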