server: respect the ignore-eos flag
This commit is contained in:
parent
08f21453ae
commit
d89f8dd0ea
|
|
@ -3000,6 +3000,8 @@ server_context_meta server_context::get_meta() const {
|
|||
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
|
||||
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
|
||||
|
||||
/* logit_bias_eog */ impl->params_base.sampling.logit_bias_eog,
|
||||
|
||||
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
||||
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
||||
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
|
||||
|
|
@ -3084,6 +3086,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
|||
ctx_server.vocab,
|
||||
params,
|
||||
meta->slot_n_ctx,
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ struct server_context_meta {
|
|||
llama_token fim_rep_token;
|
||||
llama_token fim_sep_token;
|
||||
|
||||
// sampling
|
||||
std::vector<llama_logit_bias> logit_bias_eog;
|
||||
|
||||
// model meta
|
||||
enum llama_vocab_type model_vocab_type;
|
||||
int32_t model_vocab_n_tokens;
|
||||
|
|
|
|||
|
|
@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
|
|||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data) {
|
||||
task_params params;
|
||||
|
||||
|
|
@ -562,7 +563,7 @@ task_params server_task::params_from_json_cmpl(
|
|||
if (params.sampling.ignore_eos) {
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
|
||||
logit_bias_eog.begin(), logit_bias_eog.end());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -209,6 +209,7 @@ struct server_task {
|
|||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data);
|
||||
|
||||
// utility function
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_ignore_eos_populates_logit_bias():
|
||||
"""ignore_eos=true must add EOG logit biases to generation_settings."""
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "Once upon a time",
|
||||
"ignore_eos": True,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
# EOG token biases must be present with -inf bias
|
||||
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||
assert len(logit_bias) > 0
|
||||
for entry in logit_bias:
|
||||
assert entry["bias"] is None # null in JSON represents -inf
|
||||
|
||||
|
||||
def test_ignore_eos_false_no_logit_bias():
|
||||
"""ignore_eos=false (default) must NOT add EOG logit biases."""
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "Once upon a time",
|
||||
"ignore_eos": False,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||
assert len(logit_bias) == 0
|
||||
Loading…
Reference in New Issue