sampling : always expose sampled_ids

This commit precomputes and caches the full-vocab token id list in
llama_context's constructor, so llama_get_backend_sampled_token_ids_ith
always returns a valid pointer.
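
Conceptually, the cached list is just the identity mapping over the
vocabulary: entry i holds token id i. A minimal standalone sketch of
the idea, with simplified names (the actual member is
sampled_token_ids_full_vocab, filled in the llama_context constructor
as shown in the diff below):

// Sketch only: precompute an identity token-id list once, so that a
// fallback pointer is always available even when no filtered ids were
// copied back from the device.
#include <cstdint>
#include <numeric>
#include <vector>

using llama_token = std::int32_t; // llama_token is int32_t in llama.h

std::vector<llama_token> make_full_vocab_ids(int n_vocab) {
    std::vector<llama_token> ids(n_vocab);
    std::iota(ids.begin(), ids.end(), 0); // ids[i] == i
    return ids;
}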

The motivation for this is that it enables both common/sampling.cpp
and src/llama-sampling.cpp to simplify their logic.

Not all backend samplers that process logits need to set the sampled
token ids, because they may not change the order of the logits. For
example, the temperature sampler only scales the logits without
reordering them, and similarly the logit bias sampler only adds a bias
to specific token ids. In these cases there is no device-to-host copy
of the sampled token ids, and this is exactly the case where the
precomputed full-vocab list is useful.
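
Concretely, callers can now index through the returned ids
unconditionally when building their candidate list, instead of
special-casing a null pointer. A minimal sketch of the simplified
caller pattern, mirroring the loops in the diffs below; build_candidates
is an illustrative helper (not a function in the tree), and the
sampled_probs pointer and the cur vector are assumed to be set up by
the caller, as they are in common/sampling.cpp:

#include <cstdint>
#include <vector>
#include "llama.h"

// Sketch: with sampled_ids guaranteed non-null, the old nullptr branch
// that rebuilt candidates from sequential token ids is no longer needed.
static void build_candidates(llama_context * ctx, int32_t idx,
                             const float * sampled_probs,
                             std::vector<llama_token_data> & cur) {
    // Either the backend-filtered ids, or the precomputed full-vocab
    // identity list where sampled_ids[i] == i.
    const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);

    const uint32_t n = llama_get_backend_sampled_probs_count_ith(ctx, idx);
    cur.clear();
    cur.reserve(n);
    for (uint32_t i = 0; i < n; ++i) {
        cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
    }
}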
Daniel Bevenius 2025-11-18 14:54:49 +01:00
parent 4b52e59903
commit 82957a90f2
4 changed files with 26 additions and 40 deletions

common/sampling.cpp

@@ -128,28 +128,14 @@ struct common_sampler {
         if (sampled_probs) {
             const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
             cur.reserve(sampled_probs_count);
-            // The backend sampler has filtered the probabilities so we need to use the sampled ids.
-            if (sampled_ids != nullptr) {
-                for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                    cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
-                }
-            } else {
-                for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
-                    cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
-                }
+            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+                cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
             }
         } else if (sampled_logits) {
             const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
             cur.reserve(sampled_logits_count);
-            // The backend sampler has filtered the logits so we need to use the sampled ids.
-            if (sampled_ids != nullptr) {
-                for (uint32_t i = 0; i < sampled_logits_count; i++) {
-                    cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
-                }
-            } else {
-                for (llama_token token_id = 0; token_id < (int) sampled_logits_count; token_id++) {
-                    cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
-                }
+            for (uint32_t i = 0; i < sampled_logits_count; i++) {
+                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
             }
         } else {
             const auto * logits = llama_get_logits_ith(ctx, idx);

src/llama-context.cpp

@@ -430,6 +430,16 @@ llama_context::llama_context(
             LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
         }
     }
+
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const llama_vocab * vocab = llama_model_get_vocab(&model);
+        const int n_vocab = llama_vocab_n_tokens(vocab);
+        sampled_token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampled_token_ids_full_vocab[i] = i;
+        }
+    }
 }

 llama_context::~llama_context() {
@@ -728,15 +738,18 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
 const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
     if (idx == -1) {
         if (sampled_token_ids_map.size() == 1) {
-            return sampled_token_ids_map.begin()->second.data();
+            const auto & vec = sampled_token_ids_map.begin()->second;
+            if (!vec.empty()) {
+                return vec.data();
+            }
         }
     }

     auto it = sampled_token_ids_map.find(idx);
-    if (it == sampled_token_ids_map.end() || it->second.empty()) {
-        return nullptr;
+    if (it != sampled_token_ids_map.end() && !it->second.empty()) {
+        return it->second.data();
     }
-    return it->second.data();
+    return sampled_token_ids_full_vocab.data();
 }

 size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {

src/llama-context.h

@@ -263,6 +263,7 @@ private:
     std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
     std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
+    std::vector<llama_token> sampled_token_ids_full_vocab;

     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE

src/llama-sampling.cpp

@@ -460,28 +460,14 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     if (sampled_probs) {
         const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
         cur.reserve(sampled_probs_count);
-        // The backend sampler has filtered the probabilities so we need to use the sampled ids.
-        if (sampled_ids != nullptr) {
-            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
-            }
-        } else {
-            for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
-            }
+        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+            cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
         }
     } else if (sampled_logits) {
         const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
         cur.reserve(sampled_logits_count);
-        // The backend sampler has filtered the logits so we need to use the sampled ids.
-        if (sampled_ids != nullptr) {
-            for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
-            }
-        } else {
-            for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
-            }
+        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
         }
     } else {
         const auto * logits = llama_get_logits_ith(ctx, idx);