From 341bc7d23cf8bd77639f0c5b9dcb2dc8e9713768 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Feb 2026 14:57:40 +0200 Subject: [PATCH] context : fix output reorder with backend sampling (#19638) --- src/llama-context.cpp | 47 ++++++++++++++++++++----------------------- src/llama-context.h | 14 +++++++------ 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ac17e1a0fe..99035b6cac 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -878,6 +878,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { } } catch (const std::exception & err) { // fallback to full vocab list + GGML_UNUSED(err); } return sampling.token_ids_full_vocab.data(); @@ -1809,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // uint32_t llama_context::output_reserve(int32_t n_outputs) { - const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1893,11 +1893,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); - sampling.logits = {nullptr, 0}; - sampling.probs = {nullptr, 0}; - sampling.sampled = {nullptr, 0}; - sampling.candidates = {nullptr, 0}; - if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -1923,6 +1918,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); } // set all ids as invalid (negative) @@ -1953,37 +1957,30 @@ void llama_context::output_reorder() { } } - if (sampling.logits.has_data()) { + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } - } - if (sampling.probs.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); } - } - if (sampling.candidates.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); } - } - if (sampling.sampled.has_data()) { - std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); - } - - if (!sampling.logits_count.empty()) { - std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); - } - - if (!sampling.probs_count.empty()) { - std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); - } - - if (!sampling.candidates_count.empty()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); } } diff --git a/src/llama-context.h b/src/llama-context.h index 37117ba7b6..a8e53f335c 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -265,24 +265,26 @@ private: std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - struct buffer_view logits = {nullptr, 0}; + buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - struct buffer_view embd = {nullptr, 0}; + buffer_view embd = {nullptr, 0}; struct sampling_info { + // !samplers.empty() to check if any samplers are active std::map samplers; - struct buffer_view logits = {nullptr, 0}; - struct buffer_view sampled = {nullptr, 0}; - struct buffer_view probs = {nullptr, 0}; - struct buffer_view candidates = {nullptr, 0}; + buffer_view logits = {nullptr, 0}; + buffer_view sampled = {nullptr, 0}; + buffer_view probs = {nullptr, 0}; + buffer_view candidates = {nullptr, 0}; std::vector logits_count; std::vector probs_count; std::vector candidates_count; + // optimization std::vector token_ids_full_vocab; };