context : fix output reorder with backend sampling (#19638)

commit 341bc7d23c
parent 08e6d914b8
Georgi Gerganov, 2026-02-15 14:57:40 +02:00, committed by GitHub
2 changed files with 30 additions and 31 deletions


@@ -878,6 +878,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
+        GGML_UNUSED(err);
     }
 
     return sampling.token_ids_full_vocab.data();
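A note on the pattern in this first hunk: get_sampled_candidates_ith() resolves per-output candidate ids, and on any failure it deliberately swallows the exception and serves a precomputed list covering the full vocabulary; the added GGML_UNUSED(err) only silences the unused-variable warning for the named binding. A minimal self-contained sketch of the same fallback idea (the get_candidates() helper and its containers are illustrative assumptions, not the actual llama.cpp API):

#include <cstddef>
#include <cstdio>
#include <exception>
#include <stdexcept>
#include <vector>

// Stand-in for GGML_UNUSED: mark a deliberately unused variable.
#define EXAMPLE_UNUSED(x) (void)(x)

// Hypothetical lookup: per-output candidate ids, with a full-vocab fallback.
static const int * get_candidates(const std::vector<std::vector<int>> & per_output,
                                  const std::vector<int> & full_vocab,
                                  size_t idx) {
    try {
        return per_output.at(idx).data(); // throws std::out_of_range on a bad idx
    } catch (const std::exception & err) {
        // fallback to full vocab list
        EXAMPLE_UNUSED(err);
    }
    return full_vocab.data();
}

int main() {
    const std::vector<std::vector<int>> per_output = {{3, 1, 4}};
    const std::vector<int> full_vocab = {0, 1, 2, 3, 4};

    std::printf("%d\n", get_candidates(per_output, full_vocab, 0)[0]); // 3
    std::printf("%d\n", get_candidates(per_output, full_vocab, 9)[0]); // 0 (fallback)
}

Keeping the named err binding (instead of an anonymous catch) preserves a convenient spot for a breakpoint or future logging, which is presumably why the commit silences the warning rather than deleting the name.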
@@ -1809,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 //
 uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab = model.vocab;
@@ -1893,11 +1893,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
     offset += embd.size * sizeof(float);
 
-    sampling.logits = {nullptr, 0};
-    sampling.probs = {nullptr, 0};
-    sampling.sampled = {nullptr, 0};
-    sampling.candidates = {nullptr, 0};
-
     if (has_sampling) {
         sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
         offset += sampling.logits.size * sizeof(float);
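The hunk above also shows how the output buffers are bound: output_reserve() carves typed views out of one flat allocation by advancing a byte offset. A tiny sketch of that sub-allocation pattern (buffer sizes and names are made up for illustration):

#include <cstddef>
#include <cstdint>
#include <vector>

// Carve typed regions out of one flat backing allocation, mirroring the
// `base + offset` arithmetic in output_reserve. Sizes are illustrative.
int main() {
    const size_t n_outputs = 4, n_embd = 8, n_vocab = 16;

    std::vector<uint8_t> buf((n_outputs*n_embd + n_outputs*n_vocab) * sizeof(float));
    uint8_t * base   = buf.data();
    size_t    offset = 0;

    float * embd = (float *) (base + offset);   // [n_outputs][n_embd]
    offset += n_outputs*n_embd * sizeof(float);

    float * logits = (float *) (base + offset); // [n_outputs][n_vocab]
    offset += n_outputs*n_vocab * sizeof(float);

    (void) embd; (void) logits;
    return offset == buf.size() ? 0 : 1;        // the regions exactly tile the buffer
}

Because each view is just a (pointer, size) pair into the shared block, disabling sampling costs nothing: the views are simply left unbound, which is what the else branch in the next hunk does.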
@@ -1923,6 +1918,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
         std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
+    } else {
+        sampling.logits = {nullptr, 0};
+        sampling.probs = {nullptr, 0};
+        sampling.sampled = {nullptr, 0};
+        sampling.candidates = {nullptr, 0};
+
+        sampling.logits_count.clear();
+        sampling.probs_count.clear();
+        sampling.candidates_count.clear();
     }
 
     // set all ids as invalid (negative)
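Taken together, the two output_reserve hunks change the reset strategy: the views are no longer zeroed unconditionally before the if, and the count vectors are now cleared too, so the sampling state is either fully populated or fully empty. A minimal sketch of that all-or-nothing invariant (sampling_state, reserve() and the reduced field set are assumptions that mirror the diff, not the real llama.cpp types):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed shape of the non-owning views used in the diff.
template <typename T>
struct buffer_view {
    T *    data = nullptr;
    size_t size = 0;
    bool has_data() const { return data != nullptr; }
};

struct sampling_state {
    buffer_view<float>    logits;
    buffer_view<float>    probs;
    std::vector<uint32_t> logits_count;
    std::vector<uint32_t> probs_count;
};

// Populate everything or clear everything: downstream code (e.g. a reorder
// pass) can then key off a single condition instead of per-buffer checks.
static void reserve(sampling_state & s, float * base, size_t n, bool has_sampling) {
    if (has_sampling) {
        s.logits = {base,     n};
        s.probs  = {base + n, n};
        s.logits_count.assign(n, 0);
        s.probs_count.assign(n, 0);
    } else {
        s.logits = {nullptr, 0};
        s.probs  = {nullptr, 0};
        s.logits_count.clear();
        s.probs_count.clear();
    }
    // The invariant the commit establishes: no partially initialized state.
    assert(s.logits.has_data() == s.probs.has_data());
    assert(s.logits.has_data() == !s.logits_count.empty());
}

int main() {
    float storage[8] = {};
    sampling_state s;
    reserve(s, storage, 4, true);  // all buffers populated together
    reserve(s, nullptr, 0, false); // all buffers cleared together
}

The reorder code in the next hunk can then assert this invariant instead of re-checking each buffer individually.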
@@ -1953,37 +1957,30 @@ void llama_context::output_reorder() {
             }
         }
 
-        if (sampling.logits.has_data()) {
+        if (!sampling.samplers.empty()) {
+            assert(sampling.logits.size > 0);
+            assert(sampling.probs.size > 0);
+            assert(sampling.candidates.size > 0);
+            assert(sampling.sampled.size > 0);
+
+            assert(sampling.logits_count.size() > 0);
+            assert(sampling.probs_count.size() > 0);
+            assert(sampling.candidates_count.size() > 0);
+
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.sampled.has_data()) {
-            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
-        }
-
-        if (!sampling.logits_count.empty()) {
-            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
-        }
-
-        if (!sampling.probs_count.empty()) {
-            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
-        }
-
-        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
+            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
             std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
         }
     }
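This reorder hunk is the heart of the fix: instead of swapping each sampling buffer only when it individually has_data(), the new code guards on !sampling.samplers.empty(), asserts that every buffer is populated, and swaps all of them in lockstep, so rows i0 and i1 can never end up partially exchanged. A condensed sketch of that consolidated swap (free-standing names assumed from the diff, not the full llama.cpp sources):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Swap every per-row sampling output between rows i0 and i1 in lockstep.
// Preconditions mirror the asserts in the commit: if sampling is active,
// all arrays are populated (the output_reserve invariant).
static void swap_output_rows(std::vector<float> &    logits,
                             std::vector<float> &    probs,
                             std::vector<int32_t> &  sampled,
                             std::vector<uint32_t> & logits_count,
                             uint64_t n_vocab, uint64_t i0, uint64_t i1) {
    assert(!logits.empty() && !probs.empty());
    assert(!sampled.empty() && !logits_count.empty());

    for (uint64_t k = 0; k < n_vocab; ++k) {
        std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
        std::swap(probs [i0*n_vocab + k], probs [i1*n_vocab + k]);
    }

    std::swap(sampled[i0],      sampled[i1]);
    std::swap(logits_count[i0], logits_count[i1]);
}

int main() {
    std::vector<float>    logits       = {0.f, 1.f, 2.f, 3.f}; // 2 rows x 2 vocab
    std::vector<float>    probs        = {.1f, .9f, .8f, .2f};
    std::vector<int32_t>  sampled      = {1, 0};
    std::vector<uint32_t> logits_count = {2, 2};

    swap_output_rows(logits, probs, sampled, logits_count, 2, 0, 1);
    assert(logits[0] == 2.f && sampled[0] == 0); // rows exchanged consistently
}

Under the old per-buffer checks, the arrays could disagree about being populated (for example, count vectors carried over from a previous reservation), in which case some outputs for rows i0 and i1 were swapped and others were not; that inconsistent reordering under backend sampling is what the commit title refers to.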


@@ -265,24 +265,26 @@ private:
     std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
-    struct buffer_view<float> logits = {nullptr, 0};
+    buffer_view<float> logits = {nullptr, 0};
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    struct buffer_view<float> embd = {nullptr, 0};
+    buffer_view<float> embd = {nullptr, 0};
 
     struct sampling_info {
+        // !samplers.empty() to check if any samplers are active
         std::map<llama_seq_id, llama_sampler *> samplers;
 
-        struct buffer_view<float> logits = {nullptr, 0};
-        struct buffer_view<llama_token> sampled = {nullptr, 0};
-        struct buffer_view<float> probs = {nullptr, 0};
-        struct buffer_view<llama_token> candidates = {nullptr, 0};
+        buffer_view<float> logits = {nullptr, 0};
+        buffer_view<llama_token> sampled = {nullptr, 0};
+        buffer_view<float> probs = {nullptr, 0};
+        buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector<uint32_t> logits_count;
         std::vector<uint32_t> probs_count;
         std::vector<uint32_t> candidates_count;
 
         // optimization
         std::vector<llama_token> token_ids_full_vocab;
     };
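The header hunk drops the redundant struct elaborated-type-specifier from the buffer_view members. The type's definition is not part of this diff; here is a minimal sketch of what such a non-owning view plausibly looks like, consistent with the {nullptr, 0} initializers, .data/.size accesses, and has_data() calls seen above (an assumption, not the actual llama.cpp definition):

#include <cstddef>

// Non-owning (pointer, length) view over externally managed memory.
template <typename T>
struct buffer_view {
    T *    data = nullptr;
    size_t size = 0;

    // Matches the usage in output_reorder(): a view is valid only when bound.
    bool has_data() const { return data != nullptr; }

          T & operator[](size_t i)       { return data[i]; }
    const T & operator[](size_t i) const { return data[i]; }
};

Because the view does not own its storage, output_reserve() can rebind all views into a single backing allocation (the base + offset arithmetic in the first file) without any copies, and clearing a view is just resetting it to {nullptr, 0}.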