From 3b3f5fed31ab654c4ed858b98b70df1b9855a2aa Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 18 Dec 2025 10:52:21 +0200
Subject: [PATCH] common : disable backend sampling when grammar is involved

---
 common/sampling.cpp   | 10 +++++++++-
 common/sampling.h     |  4 +++-
 src/llama-context.cpp |  2 +-
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 48e8addc23..d3b67f0fb5 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -176,7 +176,7 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -313,6 +313,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         llama_sampler_chain_add(chain, smpl);
     }
 
+    if (grmr && params.backend_sampling) {
+        LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
+
+        params.backend_sampling = false;
+    }
+
     auto * result = new common_sampler {
         /* .params = */ params,
         /* .grmr   = */ grmr,
@@ -430,6 +436,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     if (id != LLAMA_TOKEN_NULL) {
         LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
 
+        GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
+
         // TODO: simplify
         gsmpl->cur.resize(1);
         gsmpl->cur[0] = { id, 0.0f, 1.0f };
diff --git a/common/sampling.h b/common/sampling.h
index c7101032f2..5b57ad6581 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -36,7 +36,8 @@
 struct common_sampler;
 
 // llama_sampler API overloads
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
+// note: can mutate params in some cases
+struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
@@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 // arguments can be nullptr to skip printing
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 
+// get the underlying llama_sampler_chain
 struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 
 // extended sampling implementation:
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 133124b4f5..b32674ab76 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1571,7 +1571,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // extract logits
-    // For multipsequence batches that mix backend samplers and CPU sampler
+    // For multi-sequence batches that mix backend samplers and CPU sampler
     // this is currently inefficient as we copy all logits even for the
     // backend sampled tokens.
     if (logits && t_logits && n_outputs > 0) {
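
Below is a minimal caller-side sketch (not part of the patch) of the contract this change introduces, assuming an already loaded llama_model * and the grammar / backend_sampling fields declared in common/sampling.h. Because common_sampler_init() now takes params by non-const reference, the caller can observe after init that backend_sampling has been forced off when a grammar is supplied.

// sketch only: illustrates the mutate-params behavior introduced by this patch
// assumes `model` was loaded elsewhere and the grammar parses successfully
#include "sampling.h"

#include <cassert>

static void example(llama_model * model) {
    common_params_sampling sparams;

    sparams.grammar          = "root ::= [0-9]+"; // any non-empty grammar
    sparams.backend_sampling = true;              // requested, but incompatible with grammar

    // params is taken by non-const reference: common_sampler_init() logs a
    // warning and resets backend_sampling because a grammar sampler is created
    struct common_sampler * smpl = common_sampler_init(model, sparams);

    assert(sparams.backend_sampling == false); // caller sees the effective setting

    common_sampler_free(smpl);
}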