sampling : handle n_probs case

Georgi Gerganov 2025-12-08 21:30:10 +02:00
parent 6d38db5dfe
commit f3beb22b17
5 changed files with 25 additions and 21 deletions


@@ -435,6 +435,9 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     llama_token id = LLAMA_TOKEN_NULL;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
@@ -443,15 +446,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
+            // TODO: simplify
+            gsmpl->cur.resize(1);
+            gsmpl->cur[0] = { id, 0.0f, 1.0f };
+            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
             return id;
         }
     }
     gsmpl->set_logits(ctx, idx);
-    auto & chain = gsmpl->chain;
-    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
     llama_sampler_apply(chain, &cur_p);
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
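The new cur_p assignment is what makes n_probs usable on the common-sampler side: when a backend sampler has already picked the token, callers that later ask for the candidate list get a valid one-element distribution instead of whatever set_logits last produced. Below is a minimal sketch of such a caller, assuming llama.cpp's common/sampling.h header and the common_sampler_get_candidates call already used elsewhere in this diff; the helper function name is hypothetical and this is not code from the commit.

// Sketch: reading post-sampling probabilities after common_sampler_sample().
// Assumes "sampling.h" from llama.cpp's common library; when a backend sampler
// picked the token, the loop below sees exactly one entry: { id, 0.0f, 1.0f }.
#include "sampling.h"

#include <cstdio>

static void print_post_sampling_probs(struct common_sampler * smpl) {
    // the second argument requests a sorted candidate list, as in the server code
    const llama_token_data_array * cur_p = common_sampler_get_candidates(smpl, true);

    for (size_t i = 0; i < cur_p->size; i++) {
        printf("token %d: p = %.3f\n", (int) cur_p->data[i].id, cur_p->data[i].p);
    }
}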


@@ -1106,15 +1106,14 @@ static void llama_sampler_dist_backend_apply(
     // Map back to original vocab ids if a candidates tensor is available.
     struct ggml_tensor * sampled_token = idx;
     if (data->candidates != nullptr) {
-        struct ggml_tensor * candidates = data->candidates;
-        struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
-                ggml_type_size(candidates->type), 0);
+        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
-        sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
+        sampled_token = ggml_get_rows(ctx, candidates, idx);
         ggml_set_name(sampled_token, "dist_sampled_token");
     }
     data->sampled = sampled_token;
     data->probs = probs;
 }
 static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
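In the dist backend sampler, the hunk above only swaps a strided ggml_view_2d for a plain ggml_reshape_2d: the candidates tensor holds the original vocab ids of the filtered tokens, and reshaping it to one id per row lets ggml_get_rows turn the sampled index into that id. A plain CPU analogue of the mapping is sketched below; it is not the ggml graph code, and the values are made up for illustration.

// Plain C++ analogue of the reshape + ggml_get_rows step: the sampler picks an
// index into the filtered candidate list, and the candidates array maps that
// index back to the original vocab id. Values are illustrative only.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // candidates[i] = original vocab id of the i-th filtered (e.g. top-k) token
    const std::vector<int32_t> candidates = { 13, 450, 1001, 2047 };

    // index produced by the dist sampler, relative to the filtered list
    const int32_t idx = 2;

    // reshaping candidates to [1, n] makes every vocab id its own single-element
    // row, so get_rows(candidates, idx) reduces to this lookup:
    const int32_t sampled_token = candidates[idx];

    printf("filtered index %d -> vocab id %d\n", (int) idx, (int) sampled_token);

    return 0;
}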


@@ -1056,8 +1056,11 @@ struct server_context_impl {
             return false;
         }
+        // TODO: getting post/pre sampling logits is not yet supported with backend sampling
+        const bool need_logits = task.params.sampling.n_probs > 0;
         // TODO: tmp until backend sampling is fully implemented
-        if (task.params.sampling.backend_sampling) {
+        if (task.params.sampling.backend_sampling && !need_logits) {
             llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
         } else {
             llama_set_sampler(ctx, slot.id, nullptr);
@@ -1216,10 +1219,8 @@ struct server_context_impl {
         return slot.has_next_token; // continue
     }
-    // TODO: does not work with backend sampling
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-        size_t n_probs = slot.task->params.sampling.n_probs;
-        size_t n_vocab = llama_vocab_n_tokens(vocab);
+        const size_t n_probs = slot.task->params.sampling.n_probs;
         if (post_sampling) {
             const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
@@ -1247,7 +1248,7 @@ struct server_context_impl {
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
             // set probability for sampled token
-            for (size_t i = 0; i < n_vocab; i++) {
+            for (size_t i = 0; i < cur.size(); i++) {
                 // set probability for sampled token
                 if (cur[i].id == result.tok) {
                     result.prob = cur[i].p;
@@ -1257,7 +1258,7 @@ struct server_context_impl {
             // set probability for top n_probs tokens
             result.probs.reserve(n_probs);
-            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+            for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
                 result.probs.push_back({
                     cur[i].id,
                     common_token_to_piece(ctx, cur[i].id, special),
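On the server side, bounding both loops by cur.size() instead of n_vocab means populate_token_probs no longer assumes that get_token_probabilities returns one entry per vocabulary token. The standalone sketch below shows the same selection logic with simplified types and made-up values; it is not the server code itself.

// Standalone sketch of the top-n_probs selection after this change: both loops
// are bounded by the candidate vector's actual size, so a candidate list that
// is shorter than n_probs (or than the vocabulary) is handled safely.
#include <algorithm>
#include <cstdio>
#include <vector>

struct token_prob {
    int   id; // vocab id
    float p;  // probability
};

int main() {
    // hypothetical candidates, sorted by descending probability
    const std::vector<token_prob> cur = { { 42, 0.7f }, { 7, 0.2f }, { 99, 0.1f } };

    const size_t n_probs     = 10; // requested top-n, may exceed cur.size()
    const int    sampled_tok = 7;

    // probability of the sampled token
    float prob_sampled = 0.0f;
    for (size_t i = 0; i < cur.size(); i++) {
        if (cur[i].id == sampled_tok) {
            prob_sampled = cur[i].p;
            break;
        }
    }

    // top n_probs tokens, clamped to the number of available candidates
    std::vector<token_prob> probs;
    probs.reserve(n_probs);
    for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
        probs.push_back(cur[i]);
    }

    printf("sampled token %d: p = %.2f, reporting %zu of %zu requested probs\n",
           sampled_tok, prob_sampled, probs.size(), n_probs);

    return 0;
}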


@@ -301,7 +301,7 @@ def test_logprobs():
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
-        temperature=1.0,
+        temperature=0.0,
         messages=[
             {"role": "system", "content": "Book"},
             {"role": "user", "content": "What is the best book"},
@@ -328,7 +328,7 @@ def test_logprobs_stream():
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
-        temperature=1.0,
+        temperature=0.0,
         messages=[
             {"role": "system", "content": "Book"},
             {"role": "user", "content": "What is the best book"},
@@ -494,5 +494,5 @@ def test_chat_completions_multiple_choices():
     assert len(res.body["choices"]) == 2
     for choice in res.body["choices"]:
         assert "assistant" == choice["message"]["role"]
-        assert match_regex("Suddenly", choice["message"]["content"])
+        assert match_regex("Suddenly|Timmy", choice["message"]["content"])
         assert choice["finish_reason"] == "length"


@@ -441,7 +441,7 @@ def test_n_probs():
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "n_probs": 10,
-        "temperature": 1.0,
+        "temperature": 0.0,
         "n_predict": 5,
     })
     assert res.status_code == 200
@@ -466,7 +466,7 @@ def test_n_probs_stream():
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "n_probs": 10,
-        "temperature": 1.0,
+        "temperature": 0.0,
         "n_predict": 5,
         "stream": True,
     })
@@ -487,7 +487,6 @@ def test_n_probs_stream():
                 assert "bytes" in prob and type(prob["bytes"]) == list
-# TODO: this does not work with backend sampling
 def test_n_probs_post_sampling():
     global server
     server.start()
@@ -512,8 +511,8 @@ def test_n_probs_post_sampling():
             assert "token" in prob and type(prob["token"]) == str
             assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
-        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
-        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+        # at low temperature, one of the token has a very high probability
+        assert any(prob["prob"] >= 0.99 for prob in tok["top_probs"])
 @pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])