Merge d89f8dd0ea into f851fa5ab0
This commit is contained in:
commit
47449415b7
|
|
@ -3000,6 +3000,8 @@ server_context_meta server_context::get_meta() const {
|
||||||
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
|
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
|
||||||
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
|
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
|
||||||
|
|
||||||
|
/* logit_bias_eog */ impl->params_base.sampling.logit_bias_eog,
|
||||||
|
|
||||||
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
||||||
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
||||||
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
|
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
|
||||||
|
|
@ -3084,6 +3086,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||||
ctx_server.vocab,
|
ctx_server.vocab,
|
||||||
params,
|
params,
|
||||||
meta->slot_n_ctx,
|
meta->slot_n_ctx,
|
||||||
|
meta->logit_bias_eog,
|
||||||
data);
|
data);
|
||||||
task.id_slot = json_value(data, "id_slot", -1);
|
task.id_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,9 @@ struct server_context_meta {
|
||||||
llama_token fim_rep_token;
|
llama_token fim_rep_token;
|
||||||
llama_token fim_sep_token;
|
llama_token fim_sep_token;
|
||||||
|
|
||||||
|
// sampling
|
||||||
|
std::vector<llama_logit_bias> logit_bias_eog;
|
||||||
|
|
||||||
// model meta
|
// model meta
|
||||||
enum llama_vocab_type model_vocab_type;
|
enum llama_vocab_type model_vocab_type;
|
||||||
int32_t model_vocab_n_tokens;
|
int32_t model_vocab_n_tokens;
|
||||||
|
|
|
||||||
|
|
@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
|
||||||
const llama_vocab * vocab,
|
const llama_vocab * vocab,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
const int n_ctx_slot,
|
const int n_ctx_slot,
|
||||||
|
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||||
const json & data) {
|
const json & data) {
|
||||||
task_params params;
|
task_params params;
|
||||||
|
|
||||||
|
|
@ -562,7 +563,7 @@ task_params server_task::params_from_json_cmpl(
|
||||||
if (params.sampling.ignore_eos) {
|
if (params.sampling.ignore_eos) {
|
||||||
params.sampling.logit_bias.insert(
|
params.sampling.logit_bias.insert(
|
||||||
params.sampling.logit_bias.end(),
|
params.sampling.logit_bias.end(),
|
||||||
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
|
logit_bias_eog.begin(), logit_bias_eog.end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -209,6 +209,7 @@ struct server_task {
|
||||||
const llama_vocab * vocab,
|
const llama_vocab * vocab,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
const int n_ctx_slot,
|
const int n_ctx_slot,
|
||||||
|
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||||
const json & data);
|
const json & data);
|
||||||
|
|
||||||
// utility function
|
// utility function
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
import pytest
|
||||||
|
from utils import *
|
||||||
|
|
||||||
|
server = ServerPreset.tinyllama2()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def create_server():
|
||||||
|
global server
|
||||||
|
server = ServerPreset.tinyllama2()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ignore_eos_populates_logit_bias():
|
||||||
|
"""ignore_eos=true must add EOG logit biases to generation_settings."""
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"n_predict": 8,
|
||||||
|
"prompt": "Once upon a time",
|
||||||
|
"ignore_eos": True,
|
||||||
|
"temperature": 0.0,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
# EOG token biases must be present with -inf bias
|
||||||
|
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||||
|
assert len(logit_bias) > 0
|
||||||
|
for entry in logit_bias:
|
||||||
|
assert entry["bias"] is None # null in JSON represents -inf
|
||||||
|
|
||||||
|
|
||||||
|
def test_ignore_eos_false_no_logit_bias():
|
||||||
|
"""ignore_eos=false (default) must NOT add EOG logit biases."""
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"n_predict": 8,
|
||||||
|
"prompt": "Once upon a time",
|
||||||
|
"ignore_eos": False,
|
||||||
|
"temperature": 0.0,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||||
|
assert len(logit_bias) == 0
|
||||||
Loading…
Reference in New Issue