diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 9ccd8f3998..0ed729f498 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1723,40 +1723,41 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         logits = has_logits ? output_base : nullptr;
         embd   = has_embd   ? output_base + logits_size : nullptr;
     } else {
+        // Allocate worst case (full vocabulary size) for backend sampled
+        // data in the pinned memory buffer.
         size_t offset = 0;
         uint8_t * base = (uint8_t *) output_base;
 
-        if (sampling.logits_size > 0) {
-            sampling.logits = (float *) (base + offset);
-            offset += sampling.logits_size * sizeof(float);
-        }
-        if (sampling.probs_size > 0) {
-            sampling.probs = (float *) (base + offset);
-            offset += sampling.probs_size * sizeof(float);
-        }
-        if (sampling.sampled_size > 0) {
-            sampling.sampled = (llama_token *) (base + offset);
-            offset += sampling.sampled_size * sizeof(llama_token);
-        }
-        if (sampling.candidates_size > 0) {
-            sampling.candidates = (llama_token *) (base + offset);
-            offset += sampling.candidates_size * sizeof(llama_token);
-        }
+        sampling.logits = (float *) (base + offset);
+        offset += sampling.logits_size * sizeof(float);
+        sampling.probs = (float *) (base + offset);
+        offset += sampling.probs_size * sizeof(float);
+
+        sampling.sampled = (llama_token *) (base + offset);
+        offset += sampling.sampled_size * sizeof(llama_token);
+
+        sampling.candidates = (llama_token *) (base + offset);
+        offset += sampling.candidates_size * sizeof(llama_token);
+
+        // The count vectors keep track of the actual number of logits/probs/candidates
+        // copied from the backend for each output row.
        const size_t n_rows = (size_t) n_outputs_max;
        if (sampling.outputs_capacity < n_rows) {
+            // The output size has increased, so resize and reset the count vectors.
            sampling.outputs_capacity = n_rows;
            sampling.logits_count.assign(n_rows, 0);
            sampling.probs_count.assign(n_rows, 0);
            sampling.candidates_count.assign(n_rows, 0);
        } else {
+            // The output size has not increased, so just reset the counts to zero.
            std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
            std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
            std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
        }
 
-        if (sampling.sampled && sampling.sampled_size > 0) {
+        if (sampling.sampled) {
            std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
        }
    }
@@ -1814,9 +1815,11 @@ void llama_context::output_reorder() {
        if (!sampling.logits_count.empty()) {
            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
        }
+
        if (!sampling.probs_count.empty()) {
            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
        }
+
        if (!sampling.candidates_count.empty()) {
            std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
        }
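
For reference, the layout logic in the first hunk boils down to carving typed sections out of one contiguous allocation by bumping a byte offset, with every section sized for the worst case up front. Below is a minimal, self-contained sketch of that pattern; the `std::vector` buffer, the toy sizes, and the `token_t`/`TOKEN_NULL` stand-ins are illustrative assumptions, not the real llama.cpp definitions (which use the pinned output buffer, `llama_token`, and `LLAMA_TOKEN_NULL`):

```cpp
// Sketch of the offset-bump sub-buffer layout, assuming toy sizes and
// stand-in types. Not the actual llama.cpp implementation.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

using token_t = int32_t;           // stand-in for llama_token
constexpr token_t TOKEN_NULL = -1; // stand-in for LLAMA_TOKEN_NULL

int main() {
    const size_t n_rows  = 4;      // output rows (toy value)
    const size_t n_vocab = 8;      // vocabulary size (toy value)

    // Worst case: every row may need a full-vocabulary logits/probs/
    // candidates section, so size everything for n_rows * n_vocab.
    const size_t logits_size     = n_rows * n_vocab;
    const size_t probs_size      = n_rows * n_vocab;
    const size_t sampled_size    = n_rows;
    const size_t candidates_size = n_rows * n_vocab;

    const size_t total_bytes =
        (logits_size  + probs_size)      * sizeof(float) +
        (sampled_size + candidates_size) * sizeof(token_t);

    // One contiguous allocation; in llama.cpp this is the pinned output
    // buffer, but a plain vector is enough to show the carving pattern.
    std::vector<uint8_t> buffer(total_bytes);
    uint8_t * base = buffer.data();

    // Carve typed views out of the single buffer by bumping a byte offset
    // past each section; all element types here are 4 bytes wide, so every
    // section stays suitably aligned.
    size_t offset = 0;
    float   * logits     = (float *)   (base + offset); offset += logits_size     * sizeof(float);
    float   * probs      = (float *)   (base + offset); offset += probs_size      * sizeof(float);
    token_t * sampled    = (token_t *) (base + offset); offset += sampled_size    * sizeof(token_t);
    token_t * candidates = (token_t *) (base + offset); offset += candidates_size * sizeof(token_t);
    (void) logits; (void) probs; (void) candidates;

    // Pre-fill the sampled section with the "null" token so consumers can
    // tell which rows the backend actually wrote.
    std::fill_n(sampled, sampled_size, TOKEN_NULL);

    printf("carved %zu of %zu bytes into 4 sections\n", offset, total_bytes);
    return 0;
}
```

Making the carving unconditional, as the patch does, keeps each section's pointer valid and its position stable regardless of which outputs a given batch actually requests, at the cost of always reserving the worst-case footprint; the per-row count vectors then record how much of each section was actually filled.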