sampling : cleanup and clarify output_reserve

This commit is contained in:
Daniel Bevenius 2025-11-24 13:26:18 +01:00
parent d88ba1813c
commit 4a90583d7d
No known key found for this signature in database
1 changed files with 20 additions and 17 deletions

View File

@ -1723,40 +1723,41 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits = has_logits ? output_base : nullptr;
embd = has_embd ? output_base + logits_size : nullptr;
} else {
// Allocate worst case (full vocabulary size) for backend sampled
// data in the pinned memory buffer.
size_t offset = 0;
uint8_t * base = (uint8_t *) output_base;
if (sampling.logits_size > 0) {
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);
}
if (sampling.probs_size > 0) {
sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);
}
if (sampling.sampled_size > 0) {
sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);
}
if (sampling.candidates_size > 0) {
sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);
}
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);
sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);
sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);
sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);
// The count vectors keep track of the actual number of logits/probs/candidates
// copied from the backend for each output row.
const size_t n_rows = (size_t) n_outputs_max;
if (sampling.outputs_capacity < n_rows) {
// The output size has increased, so resize and reset the count vectors.
sampling.outputs_capacity = n_rows;
sampling.logits_count.assign(n_rows, 0);
sampling.probs_count.assign(n_rows, 0);
sampling.candidates_count.assign(n_rows, 0);
} else {
// The output size has not increased so just reset the counts to zero.
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
}
if (sampling.sampled && sampling.sampled_size > 0) {
if (sampling.sampled) {
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
}
}
@ -1814,9 +1815,11 @@ void llama_context::output_reorder() {
if (!sampling.logits_count.empty()) {
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
}
if (!sampling.probs_count.empty()) {
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
}
if (!sampling.candidates_count.empty()) {
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
}