From f3beb22b17eb6bbf7096a132d9c3b3ac82453e72 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 8 Dec 2025 21:30:10 +0200
Subject: [PATCH] sampling : handle n_probs case

---
 common/sampling.cpp                              | 11 ++++++++---
 src/llama-sampling.cpp                           |  7 +++----
 tools/server/server-context.cpp                  | 13 +++++++------
 tools/server/tests/unit/test_chat_completion.py  |  6 +++---
 tools/server/tests/unit/test_completion.py       |  9 ++++-----
 5 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 8095d8ec22..aefc596443 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -435,6 +435,9 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
     llama_token id = LLAMA_TOKEN_NULL;
 
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
+
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
@@ -443,15 +446,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
+            // TODO: simplify
+            gsmpl->cur.resize(1);
+            gsmpl->cur[0] = { id, 0.0f, 1.0f };
+            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
+
             return id;
         }
     }
 
     gsmpl->set_logits(ctx, idx);
 
-    auto & chain = gsmpl->chain;
-    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
-
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d70b765e63..86f82d1691 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1106,15 +1106,14 @@ static void llama_sampler_dist_backend_apply(
     // Map back to original vocab ids if a candidates tensor is available.
     struct ggml_tensor * sampled_token = idx;
     if (data->candidates != nullptr) {
-        struct ggml_tensor * candidates = data->candidates;
-        struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
-                ggml_type_size(candidates->type), 0);
+        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
 
-        sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
+        sampled_token = ggml_get_rows(ctx, candidates, idx);
         ggml_set_name(sampled_token, "dist_sampled_token");
     }
 
     data->sampled = sampled_token;
+    data->probs = probs;
 }
 
 static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 29da4a200d..0e4224e115 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1056,8 +1056,11 @@ struct server_context_impl {
             return false;
         }
 
+        // TODO: getting post/pre sampling logits is not yet supported with backend sampling
+        const bool need_logits = task.params.sampling.n_probs > 0;
+
         // TODO: tmp until backend sampling is fully implemented
-        if (task.params.sampling.backend_sampling) {
+        if (task.params.sampling.backend_sampling && !need_logits) {
             llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
         } else {
             llama_set_sampler(ctx, slot.id, nullptr);
@@ -1216,10 +1219,8 @@ struct server_context_impl {
         return slot.has_next_token; // continue
     }
 
-    // TODO: does not work with backend sampling
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-        size_t n_probs = slot.task->params.sampling.n_probs;
-        size_t n_vocab = llama_vocab_n_tokens(vocab);
+        const size_t n_probs = slot.task->params.sampling.n_probs;
 
         if (post_sampling) {
             const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
@@ -1247,7 +1248,7 @@ struct server_context_impl {
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
 
             // set probability for sampled token
-            for (size_t i = 0; i < n_vocab; i++) {
+            for (size_t i = 0; i < cur.size(); i++) {
                 // set probability for sampled token
                 if (cur[i].id == result.tok) {
                     result.prob = cur[i].p;
@@ -1257,7 +1258,7 @@ struct server_context_impl {
 
             // set probability for top n_probs tokens
             result.probs.reserve(n_probs);
-            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+            for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
                 result.probs.push_back({
                     cur[i].id,
                     common_token_to_piece(ctx, cur[i].id, special),
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 500bae1eca..08b5265d48 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -301,7 +301,7 @@ def test_logprobs():
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
-        temperature=1.0,
+        temperature=0.0,
         messages=[
             {"role": "system", "content": "Book"},
             {"role": "user", "content": "What is the best book"},
@@ -328,7 +328,7 @@ def test_logprobs_stream():
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
-        temperature=1.0,
+        temperature=0.0,
        messages=[
             {"role": "system", "content": "Book"},
             {"role": "user", "content": "What is the best book"},
@@ -494,5 +494,5 @@ def test_chat_completions_multiple_choices():
     assert len(res.body["choices"]) == 2
     for choice in res.body["choices"]:
         assert "assistant" == choice["message"]["role"]
-        assert match_regex("Suddenly", choice["message"]["content"])
+        assert match_regex("Suddenly|Timmy", choice["message"]["content"])
         assert choice["finish_reason"] == "length"
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index 57fca8231a..daaa6e5a90 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -441,7 +441,7 @@ def test_n_probs():
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "n_probs": 10,
-        "temperature": 1.0,
+        "temperature": 0.0,
         "n_predict": 5,
     })
     assert res.status_code == 200
@@ -466,7 +466,7 @@ def test_n_probs_stream():
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "n_probs": 10,
-        "temperature": 1.0,
+        "temperature": 0.0,
         "n_predict": 5,
         "stream": True,
     })
@@ -487,7 +487,6 @@ def test_n_probs_stream():
             assert "bytes" in prob and type(prob["bytes"]) == list
 
 
-# TODO: this does not work with backend sampling
 def test_n_probs_post_sampling():
     global server
     server.start()
@@ -512,8 +511,8 @@ def test_n_probs_post_sampling():
             assert "token" in prob and type(prob["token"]) == str
             assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
-        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
-        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+        # at low temperature, one of the tokens has a very high probability
+        assert any(prob["prob"] >= 0.99 for prob in tok["top_probs"])
 
 
 @pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
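
The sketch below is not part of the patch; it is a hedged usage example of the n_probs path that the updated tests exercise, sent against a locally running llama-server. The host/port (http://localhost:8080), the use of the requests library, and the "completion_probabilities" response key are assumptions; the request fields and per-token keys mirror the ones asserted in test_n_probs and test_n_probs_post_sampling above.

# Illustrative only: assumes llama-server is listening on http://localhost:8080
# (placeholder host/port) and that the /completion response carries an
# assumed "completion_probabilities" list for n_probs output.
import requests

resp = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,       # request the top-10 token probabilities
        "temperature": 0.0,  # near-deterministic sampling, as in the updated tests
        "n_predict": 5,
    },
    timeout=60,
)
resp.raise_for_status()

# Per-token keys ("token", "prob", "top_probs") follow the assertions in
# test_n_probs_post_sampling above.
for tok in resp.json().get("completion_probabilities", []):
    top = tok.get("top_probs", [])
    print(tok.get("token"), [(p["token"], p["prob"]) for p in top])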