From 82957a90f21f37b6862a3eb900b54b0de6687e23 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Tue, 18 Nov 2025 14:54:49 +0100
Subject: [PATCH] sampling : always expose sampled_ids

This commit precomputes and caches the full-vocab token id list in
llama_context's constructor, so llama_get_backend_sampled_token_ids_ith
always returns a valid pointer. The motivation for this is that it
allows both common/sampling.cpp and src/llama-sampling.cpp to simplify
their logic.

Not all backend samplers that process logits need to set the sampled
token ids, as they may not change the order of the logits. For example,
the temperature sampler only scales the logits but does not change
their order. Similarly, the logit bias sampler only adds a bias to
specific token ids but does not change the order of the logits. In
these cases there is no device-to-host copy of the sampled token ids,
and this is the use case where the precomputed list is useful.
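For illustration, here is a minimal sketch of the caller-side pattern this
enables (not part of the patch: the helper name collect_candidates is made
up, llama_get_backend_sampled_logits_ith is assumed to be the public
counterpart of the member function changed below, and model/decode setup
is omitted):

    #include "llama.h"

    #include <vector>

    // Build the candidate list from backend-sampled logits. After this change
    // sampled_ids no longer needs a nullptr check: it is either the id list
    // produced by the backend sampler, or the precomputed identity list
    // 0..n_vocab-1 when the sampler did not reorder the logits.
    static std::vector<llama_token_data> collect_candidates(llama_context * ctx, int32_t idx) {
        std::vector<llama_token_data> cur;

        const float       * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx); // assumed name
        const llama_token * sampled_ids    = llama_get_backend_sampled_token_ids_ith(ctx, idx);

        if (sampled_logits) {
            const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
            cur.reserve(sampled_logits_count);
            for (uint32_t i = 0; i < sampled_logits_count; i++) {
                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
            }
        }
        return cur;
    }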
---
 common/sampling.cpp    | 22 ++++------------------
 src/llama-context.cpp  | 21 +++++++++++++++++----
 src/llama-context.h    |  1 +
 src/llama-sampling.cpp | 22 ++++------------------
 4 files changed, 26 insertions(+), 40 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index ec61c18832..9c707a5bb9 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -128,28 +128,14 @@ struct common_sampler {
         if (sampled_probs) {
             const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
             cur.reserve(sampled_probs_count);
-            // The backend sampler has filtered the probabilities so we need to use the sampled ids.
-            if (sampled_ids != nullptr) {
-                for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                    cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
-                }
-            } else {
-                for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
-                    cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
-                }
+            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+                cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
             }
         } else if (sampled_logits) {
             const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
             cur.reserve(sampled_logits_count);
-            // The backend sampler has filtered the logits so we need to use the sampled ids.
-            if (sampled_ids != nullptr) {
-                for (uint32_t i = 0; i < sampled_logits_count; i++) {
-                    cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
-                }
-            } else {
-                for (llama_token token_id = 0; token_id < (int) sampled_logits_count; token_id++) {
-                    cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
-                }
+            for (uint32_t i = 0; i < sampled_logits_count; i++) {
+                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
             }
         } else {
             const auto * logits = llama_get_logits_ith(ctx, idx);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f931881c9c..25d3528434 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -430,6 +430,16 @@ llama_context::llama_context(
             LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
         }
     }
+
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const llama_vocab * vocab = llama_model_get_vocab(&model);
+        const int n_vocab = llama_vocab_n_tokens(vocab);
+        sampled_token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampled_token_ids_full_vocab[i] = i;
+        }
+    }
 }
 
 llama_context::~llama_context() {
@@ -728,15 +738,18 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
 const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
     if (idx == -1) {
         if (sampled_token_ids_map.size() == 1) {
-            return sampled_token_ids_map.begin()->second.data();
+            const auto & vec = sampled_token_ids_map.begin()->second;
+            if (!vec.empty()) {
+                return vec.data();
+            }
         }
     }
 
     auto it = sampled_token_ids_map.find(idx);
-    if (it == sampled_token_ids_map.end() || it->second.empty()) {
-        return nullptr;
+    if (it != sampled_token_ids_map.end() && !it->second.empty()) {
+        return it->second.data();
     }
-    return it->second.data();
+    return sampled_token_ids_full_vocab.data();
 }
 
 size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
diff --git a/src/llama-context.h b/src/llama-context.h
index b9020beff1..aba62e6e38 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -263,6 +263,7 @@ private:
 
     std::unordered_map<int32_t, std::vector<float>>       sampled_logits_map;
     std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
+    std::vector<llama_token> sampled_token_ids_full_vocab;
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index dc9227c1a5..d210b826c7 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -460,28 +460,14 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     if (sampled_probs) {
         const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
         cur.reserve(sampled_probs_count);
-        // The backend sampler has filtered the probabilities so we need to use the sampled ids.
-        if (sampled_ids != nullptr) {
-            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
-            }
-        } else {
-            for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
-            }
+        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+            cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
         }
     } else if (sampled_logits) {
         const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
         cur.reserve(sampled_logits_count);
-        // The backend sampler has filtered the logits so we need to use the sampled ids.
-        if (sampled_ids != nullptr) {
-            for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
-            }
-        } else {
-            for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
-            }
+        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
         }
     } else {
         const auto * logits = llama_get_logits_ith(ctx, idx);