From d89f8dd0ea120c0d7701d83f12f69baaa7fdd16b Mon Sep 17 00:00:00 2001
From: Yuri Khrustalev
Date: Mon, 30 Mar 2026 18:43:11 -0700
Subject: [PATCH] server: respect the ignore_eos flag

---
 tools/server/server-context.cpp            |  3 ++
 tools/server/server-context.h              |  3 ++
 tools/server/server-task.cpp               |  3 ++-
 tools/server/server-task.h                 |  1 +
 tools/server/tests/unit/test_ignore_eos.py | 43 ++++++++++++++++++++++
 5 files changed, 52 insertions(+), 1 deletion(-)
 create mode 100644 tools/server/tests/unit/test_ignore_eos.py

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 6f737d94d0..d002ea1c3b 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -3000,6 +3000,8 @@ server_context_meta server_context::get_meta() const {
         /* fim_rep_token        */ llama_vocab_fim_rep(impl->vocab),
         /* fim_sep_token        */ llama_vocab_fim_sep(impl->vocab),
 
+        /* logit_bias_eog       */ impl->params_base.sampling.logit_bias_eog,
+
         /* model_vocab_type     */ llama_vocab_type(impl->vocab),
         /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
         /* model_n_ctx_train    */ llama_model_n_ctx_train(impl->model),
@@ -3084,6 +3086,7 @@ std::unique_ptr server_routes::handle_completions_impl(
             ctx_server.vocab,
             params,
             meta->slot_n_ctx,
+            meta->logit_bias_eog,
             data);
 
         task.id_slot = json_value(data, "id_slot", -1);
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index a4d2201cbe..fa71ace978 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -39,6 +39,9 @@ struct server_context_meta {
     llama_token fim_rep_token;
     llama_token fim_sep_token;
 
+    // sampling
+    std::vector<llama_logit_bias> logit_bias_eog;
+
     // model meta
     enum llama_vocab_type model_vocab_type;
     int32_t model_vocab_n_tokens;
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 3018ac90f8..8dada57994 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
         const llama_vocab * vocab,
         const common_params & params_base,
         const int n_ctx_slot,
+        const std::vector<llama_logit_bias> & logit_bias_eog,
         const json & data) {
     task_params params;
 
@@ -562,7 +563,7 @@
     if (params.sampling.ignore_eos) {
         params.sampling.logit_bias.insert(
             params.sampling.logit_bias.end(),
-            defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
+            logit_bias_eog.begin(), logit_bias_eog.end());
     }
 }
 
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index a49ddb594b..0b319142aa 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -209,6 +209,7 @@ struct server_task {
         const llama_vocab * vocab,
         const common_params & params_base,
         const int n_ctx_slot,
+        const std::vector<llama_logit_bias> & logit_bias_eog,
         const json & data);
 
     // utility function
diff --git a/tools/server/tests/unit/test_ignore_eos.py b/tools/server/tests/unit/test_ignore_eos.py
new file mode 100644
index 0000000000..f40faf5a82
--- /dev/null
+++ b/tools/server/tests/unit/test_ignore_eos.py
@@ -0,0 +1,43 @@
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+
+def test_ignore_eos_populates_logit_bias():
+    """ignore_eos=true must add EOG logit biases to generation_settings."""
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "Once upon a time",
+        "ignore_eos": True,
+        "temperature": 0.0,
+    })
+    assert res.status_code == 200
+    # EOG token biases must be present with -inf bias
+    logit_bias = res.body["generation_settings"]["logit_bias"]
+    assert len(logit_bias) > 0
+    for entry in logit_bias:
+        assert entry["bias"] is None  # null in JSON represents -inf
+
+
+def test_ignore_eos_false_no_logit_bias():
+    """ignore_eos=false (default) must NOT add EOG logit biases."""
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "Once upon a time",
+        "ignore_eos": False,
+        "temperature": 0.0,
+    })
+    assert res.status_code == 200
+    logit_bias = res.body["generation_settings"]["logit_bias"]
+    assert len(logit_bias) == 0