sampling : handle n_probs case

parent 6d38db5dfe
commit f3beb22b17
@@ -435,6 +435,9 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     llama_token id = LLAMA_TOKEN_NULL;
 
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
+
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
@@ -443,15 +446,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
 
+            // TODO: simplify
+            gsmpl->cur.resize(1);
+            gsmpl->cur[0] = { id, 0.0f, 1.0f };
+            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
+
             return id;
         }
     }
 
     gsmpl->set_logits(ctx, idx);
 
-    auto & chain = gsmpl->chain;
-    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
-
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
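Note: the early-return path above now leaves the sampler's candidate array in a well-defined state instead of skipping it entirely, which is what makes n_probs usable alongside backend sampling. Below is a minimal caller-side sketch of that invariant; it assumes the two-argument common_sampler_get_candidates() that appears later in this diff (the boolean is taken to request a sorted view) and the llama_token_data layout { id, logit, p } from llama.h.

#include <cassert>

#include "sampling.h" // common/sampling.h: common_sampler_* helpers

// Sketch only: after a backend sampler has picked the token, common_sampler_sample()
// rewrites cur_p to a single entry { id, logit = 0.0f, p = 1.0f } before returning,
// so probability readers get a well-formed candidate array for this call as well.
static void check_backend_sampled_candidates(common_sampler * smpl, llama_context * ctx, int idx) {
    const llama_token id = common_sampler_sample(smpl, ctx, idx);

    const llama_token_data_array * cur_p = common_sampler_get_candidates(smpl, /*assumed: sort=*/true);
    assert(cur_p->size >= 1);

    if (cur_p->size == 1) {                // backend-sampled path
        assert(cur_p->data[0].id == id);   // the sampled token itself ...
        assert(cur_p->data[0].p == 1.0f);  // ... reported with probability 1
    }
}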
@@ -1106,15 +1106,14 @@ static void llama_sampler_dist_backend_apply(
     // Map back to original vocab ids if a candidates tensor is available.
     struct ggml_tensor * sampled_token = idx;
     if (data->candidates != nullptr) {
-        struct ggml_tensor * candidates = data->candidates;
-        struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
-                                                                ggml_type_size(candidates->type), 0);
+        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
 
-        sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
+        sampled_token = ggml_get_rows(ctx, candidates, idx);
         ggml_set_name(sampled_token, "dist_sampled_token");
     }
 
     data->sampled = sampled_token;
     data->probs = probs;
 }
 
 static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
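For readers unfamiliar with the shapes involved: the candidates tensor holds the original vocab ids kept by the earlier backend sampler stages, and idx is the position the dist sampler selected within it. Reshaping to [1, n_candidates] lets ggml_get_rows() pick that position as a row, mapping the sampled index back to a vocab id. A plain-C++ illustration of the same mapping follows (not from the commit; token ids are assumed to be int32).

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar equivalent of ggml_get_rows(ggml_reshape_2d(candidates, 1, n), idx):
// row i of the reshaped [1, n] tensor is simply candidates[i], so looking up the
// sampled row yields the original vocab id of the chosen candidate.
static int32_t map_candidate_to_vocab_id(const std::vector<int32_t> & candidates, int64_t idx) {
    assert(idx >= 0 && idx < (int64_t) candidates.size());
    return candidates[idx]; // what the "dist_sampled_token" tensor holds after the graph runs
}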
@@ -1056,8 +1056,11 @@ struct server_context_impl {
             return false;
         }
 
+        // TODO: getting post/pre sampling logits is not yet supported with backend sampling
+        const bool need_logits = task.params.sampling.n_probs > 0;
+
         // TODO: tmp until backend sampling is fully implemented
-        if (task.params.sampling.backend_sampling) {
+        if (task.params.sampling.backend_sampling && !need_logits) {
             llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
         } else {
             llama_set_sampler(ctx, slot.id, nullptr);
@@ -1216,10 +1219,8 @@ struct server_context_impl {
         return slot.has_next_token; // continue
     }
 
-    // TODO: does not work with backend sampling
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-        size_t n_probs = slot.task->params.sampling.n_probs;
-        size_t n_vocab = llama_vocab_n_tokens(vocab);
+        const size_t n_probs = slot.task->params.sampling.n_probs;
 
         if (post_sampling) {
             const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
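Because of the common_sampler_sample() change earlier in this commit, the post-sampling branch above can rely on the candidate array being populated whether the token came from the CPU chain or from a backend sampler. A hedged sketch of the kind of lookup it performs (the two-argument common_sampler_get_candidates() is taken from the hunk above; the loop body here is illustrative, not the server's exact code):

#include <cstddef>

#include "sampling.h" // common/sampling.h

// Sketch only: resolve the probability of the sampled token from the candidate array.
// With the changes above this works both for CPU sampling (full candidate list) and
// for backend sampling (a single { id, 0.0f, 1.0f } entry).
static float sampled_token_prob(common_sampler * smpl, llama_token tok) {
    const llama_token_data_array * cur_p = common_sampler_get_candidates(smpl, /*assumed: sort=*/true);

    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].id == tok) {
            return cur_p->data[i].p; // probability assigned by the sampler chain
        }
    }
    return 0.0f; // token not among the kept candidates
}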
@@ -1247,7 +1248,7 @@ struct server_context_impl {
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
 
-            // set probability for sampled token
-            for (size_t i = 0; i < n_vocab; i++) {
+            for (size_t i = 0; i < cur.size(); i++) {
+                // set probability for sampled token
                 if (cur[i].id == result.tok) {
                     result.prob = cur[i].p;
@@ -1257,7 +1258,7 @@ struct server_context_impl {
 
             // set probability for top n_probs tokens
             result.probs.reserve(n_probs);
-            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+            for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
                 result.probs.push_back({
                     cur[i].id,
                     common_token_to_piece(ctx, cur[i].id, special),
@@ -301,7 +301,7 @@ def test_logprobs():
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
-        temperature=1.0,
+        temperature=0.0,
         messages=[
             {"role": "system", "content": "Book"},
             {"role": "user", "content": "What is the best book"},
@@ -328,7 +328,7 @@ def test_logprobs_stream():
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
-        temperature=1.0,
+        temperature=0.0,
         messages=[
             {"role": "system", "content": "Book"},
             {"role": "user", "content": "What is the best book"},
@@ -494,5 +494,5 @@ def test_chat_completions_multiple_choices():
     assert len(res.body["choices"]) == 2
     for choice in res.body["choices"]:
         assert "assistant" == choice["message"]["role"]
-        assert match_regex("Suddenly", choice["message"]["content"])
+        assert match_regex("Suddenly|Timmy", choice["message"]["content"])
         assert choice["finish_reason"] == "length"
@@ -441,7 +441,7 @@ def test_n_probs():
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "n_probs": 10,
-        "temperature": 1.0,
+        "temperature": 0.0,
         "n_predict": 5,
     })
     assert res.status_code == 200
@@ -466,7 +466,7 @@ def test_n_probs_stream():
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "n_probs": 10,
-        "temperature": 1.0,
+        "temperature": 0.0,
         "n_predict": 5,
         "stream": True,
     })
@@ -487,7 +487,6 @@ def test_n_probs_stream():
                 assert "bytes" in prob and type(prob["bytes"]) == list
 
 
-# TODO: this does not work with backend sampling
 def test_n_probs_post_sampling():
     global server
     server.start()
@@ -512,8 +511,8 @@ def test_n_probs_post_sampling():
             assert "token" in prob and type(prob["token"]) == str
             assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
-        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
-        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+        # at low temperature, one of the token has a very high probability
+        assert any(prob["prob"] >= 0.99 for prob in tok["top_probs"])
 
 
 @pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])