sampling : remove sampling branching in output_reserve (#18811)

* sampling : remove sampling branching in output_reserve

This commit updates output_reserve in llama-context.cpp to always
allocate sampling buffers regardless of whether sampling is needed for
the current batch.

The motivation for this is to avoid reallocations and branching based on
the sampling requirements of the batch.
Daniel Bevenius 2026-01-28 05:59:30 +01:00 committed by GitHub
parent 06961e2876
commit eef375ce16
2 changed files with 33 additions and 57 deletions
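
To make the change concrete, here is a minimal, self-contained C++ sketch of what the now-unconditional backend sampling buffers amount to. It is illustrative only, not llama.cpp code: n_vocab and n_outputs_max are assumed example values, and only the buffer sizes visible in the diff below are mirrored (candidates_size is omitted because its new value is not shown in the hunk).

#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative values only; they are not taken from the commit.
    const size_t n_vocab       = 32000;
    const size_t n_outputs_max = 8;

    // Sizes mirrored from output_reserve() after this change (see the diff below),
    // reserved whenever any backend sampler is configured, regardless of the batch.
    const size_t sampling_logits_floats = n_vocab*n_outputs_max;
    const size_t sampling_probs_floats  = n_vocab*n_outputs_max;
    const size_t sampled_tokens         = n_outputs_max;

    const double MiB = 1024.0*1024.0;
    std::printf("sampling logits: %.2f MiB\n", sampling_logits_floats*sizeof(float)/MiB);
    std::printf("sampling probs : %.2f MiB\n", sampling_probs_floats *sizeof(float)/MiB);
    std::printf("sampled tokens : %zu entries\n", sampled_tokens);
    return 0;
}

For a 32k vocabulary and 8 reserved outputs this is roughly 1 MiB each for the sampling logits and probabilities; the commit trades this fixed worst-case reservation for code with no per-batch branching or reallocation.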

llama-context.cpp

@@ -253,11 +253,7 @@ llama_context::llama_context(
     // graph outputs buffer
     {
-        // resized during inference when a batch uses more outputs
-        // Create a dummy batch for initialization.
-        llama_batch dummy_batch = {};
-        dummy_batch.n_tokens = 0;
-        if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
+        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
             throw std::runtime_error("failed to reserve initial output buffer");
         }
@@ -1225,7 +1221,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     n_queued_tokens += n_tokens;

     // reserve output buffer
-    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
+    if (output_reserve(n_tokens) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
@@ -1456,6 +1452,23 @@ static void copy_tensor_async_candidates(
     }
 }

+static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (!ubatch.output[i]) {
+            continue;
+        }
+
+        // Check if the output token has at least one sequence without a backend sampler.
+        for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
+            llama_seq_id seq_id = ubatch.seq_id[i][j];
+            if (samplers.find(seq_id) == samplers.end()) {
+                return true;
+            }
+        }
+    }
+
+    return false; // all sequences use backend sampling
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
@@ -1588,7 +1601,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
    }

    // reserve output buffer
-   if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
+   if (output_reserve(n_outputs_all) < n_outputs_all) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
        return -2;
    };
@@ -1661,10 +1674,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
        }

        // extract logits
-       // For multi-sequence batches that mix backend samplers and CPU sampler
-       // this is currently inefficient as we copy all logits even for the
-       // backend sampled tokens.
-       if (logits && t_logits && n_outputs > 0) {
+       if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
            GGML_ASSERT(backend_res != nullptr);
            GGML_ASSERT(logits != nullptr);
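
As an illustration of the gate introduced above, the sketch below mimics needs_raw_logits() with toy stand-in types (toy_ubatch and sampler are not llama.cpp types): raw logits only have to be copied back to the CPU when some output token belongs to a sequence without a backend sampler.

#include <cstdio>
#include <map>
#include <vector>

struct sampler {};  // stand-in for llama_sampler

struct toy_ubatch {
    std::vector<bool>             output;  // per token: does it produce an output?
    std::vector<std::vector<int>> seq_id;  // per token: the sequences it belongs to
};

// Same shape as needs_raw_logits() above: true if any output token has at least
// one sequence without a backend sampler, i.e. the CPU still needs raw logits.
static bool toy_needs_raw_logits(const toy_ubatch & ub, const std::map<int, sampler *> & samplers) {
    for (size_t i = 0; i < ub.output.size(); i++) {
        if (!ub.output[i]) {
            continue;
        }
        for (int s : ub.seq_id[i]) {
            if (samplers.find(s) == samplers.end()) {
                return true;
            }
        }
    }
    return false;
}

int main() {
    sampler backend_sampler;
    std::map<int, sampler *> samplers = { {0, &backend_sampler} };  // seq 0 sampled on the backend

    toy_ubatch all_backend = { {true},       {{0}}      };  // only seq 0      -> logits copy skipped
    toy_ubatch mixed       = { {true, true}, {{0}, {1}} };  // seq 1 unsampled -> logits copy needed

    std::printf("all_backend needs raw logits: %d\n", toy_needs_raw_logits(all_backend, samplers));
    std::printf("mixed       needs raw logits: %d\n", toy_needs_raw_logits(mixed, samplers));
    return 0;
}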
@@ -1734,11 +1744,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
            }
        }

-       // This flag indicates whether a backend sampler has actually sampled a specific
-       // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
-       const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
-       if (has_samplers && has_sampled) {
+       // Copy backend sampling output if this ubatch produced any sampling tensors.
+       if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
            const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
            const auto stride = n_vocab;
@@ -1813,7 +1820,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // output
 //

-uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
    const auto & hparams = model.hparams;
    const auto & vocab = model.vocab;
@@ -1832,45 +1840,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
        has_embd = true;
    }

-   // Check which sampling modes are needed for the current batch.
-   // TODO: avoid this branching by working with the worst-case
-   bool has_sampling = false;
-   bool cpu_logits = false;
-   if (batch.logits) {
-       for (int32_t i = 0; i < batch.n_tokens; i++) {
-           if (!batch.logits[i]) {
-               continue;
-           }
-           for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
-               llama_seq_id seq_id = batch.seq_id[i][j];
-               if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
-                   has_sampling = true;
-               } else {
-                   cpu_logits = true;
-               }
-           }
-       }
-   } else {
-       // When batch.logits is nullptr (when loading state with a dummy batch),
-       // allocate CPU logits.
-       cpu_logits = true;
-   }

    size_t backend_float_count = 0;
    size_t backend_token_count = 0;

-   // Allocate CPU logits buffer only if needed by sequences in this batch
-   logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
+   logits_size = has_logits ? n_vocab*n_outputs_max : 0;
    embd_size = has_embd ? n_embd_out*n_outputs_max : 0;

-   // TODO: avoid this branching by working with the worst-case
-   if (!has_sampling) {
-       sampling.logits_size = 0;
-       sampling.probs_size = 0;
-       sampling.sampled_size = 0;
-       sampling.candidates_size = 0;
-   } else {
+   // Allocate backend sampling output buffers if there are backend samplers configured.
+   const bool has_sampling = !sampling.samplers.empty();
+   if (has_sampling) {
        sampling.logits_size = n_vocab*n_outputs_max;
        sampling.probs_size = n_vocab*n_outputs_max;
        sampling.sampled_size = n_outputs_max;
@@ -1928,7 +1907,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
    size_t offset = 0;
    uint8_t * base = (uint8_t *) output_base;

-   logits = (has_logits && cpu_logits) ? output_base : nullptr;
+   logits = has_logits ? output_base : nullptr;
    offset += logits_size * sizeof(float);

    embd = has_embd ? (float *) (base + offset) : nullptr;
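
For readers unfamiliar with the base/offset pattern in this hunk, here is a small self-contained sketch of carving one flat host buffer into logits and embeddings regions. The sizes are assumed example values, not values from the commit.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Assumed example sizes; the real values come from the model's hparams and vocab.
    const size_t n_vocab       = 32000;
    const size_t n_embd_out    = 4096;
    const size_t n_outputs_max = 8;

    const bool has_logits = true;
    const bool has_embd   = true;

    const size_t logits_size = has_logits ? n_vocab*n_outputs_max    : 0;  // in floats
    const size_t embd_size   = has_embd   ? n_embd_out*n_outputs_max : 0;  // in floats

    // One flat host allocation, carved with a running byte offset, mirroring
    // the base/offset arithmetic in output_reserve().
    std::vector<uint8_t> output_buf((logits_size + embd_size) * sizeof(float));

    uint8_t * base   = output_buf.data();
    size_t    offset = 0;

    float * logits = has_logits ? (float *) (base + offset) : nullptr;
    offset += logits_size * sizeof(float);

    float * embd = has_embd ? (float *) (base + offset) : nullptr;
    offset += embd_size * sizeof(float);

    std::printf("total bytes: %zu, logits at +0, embd at +%zu\n",
                offset, embd ? (size_t)((uint8_t *) embd - base) : 0);
    (void) logits;
    return 0;
}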
@@ -2614,10 +2593,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
    auto n_outputs = this->n_outputs;
    io.read_to(&n_outputs, sizeof(n_outputs));

-   // Create a dummy batch for state loading.
-   llama_batch dummy_batch = {};
-   dummy_batch.n_tokens = 0;
-   if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
+   if (n_outputs > output_reserve(n_outputs)) {
        throw std::runtime_error("could not reserve outputs");
    }
@@ -2862,7 +2838,7 @@ void llama_context::opt_epoch_iter(
        }

        // reserve output buffer
-       if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
+       if (output_reserve(n_outputs_all) < n_outputs_all) {
            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
            GGML_ABORT("TODO: handle this error");
        };

llama-context.h

@@ -212,7 +212,7 @@ private:
    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
-   uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
+   uint32_t output_reserve(int32_t n_outputs);
    void output_reorder();