sampling : cleanup and clarify output_reserve
This commit is contained in:
parent
d88ba1813c
commit
4a90583d7d
|
|
@ -1723,40 +1723,41 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
logits = has_logits ? output_base : nullptr;
|
||||
embd = has_embd ? output_base + logits_size : nullptr;
|
||||
} else {
|
||||
// Allocate worst case (full vocabulary size) for backend sampled
|
||||
// data in the pinned memory buffer.
|
||||
size_t offset = 0;
|
||||
uint8_t * base = (uint8_t *) output_base;
|
||||
|
||||
if (sampling.logits_size > 0) {
|
||||
sampling.logits = (float *) (base + offset);
|
||||
offset += sampling.logits_size * sizeof(float);
|
||||
}
|
||||
if (sampling.probs_size > 0) {
|
||||
sampling.probs = (float *) (base + offset);
|
||||
offset += sampling.probs_size * sizeof(float);
|
||||
}
|
||||
if (sampling.sampled_size > 0) {
|
||||
sampling.sampled = (llama_token *) (base + offset);
|
||||
offset += sampling.sampled_size * sizeof(llama_token);
|
||||
}
|
||||
if (sampling.candidates_size > 0) {
|
||||
sampling.candidates = (llama_token *) (base + offset);
|
||||
offset += sampling.candidates_size * sizeof(llama_token);
|
||||
}
|
||||
sampling.logits = (float *) (base + offset);
|
||||
offset += sampling.logits_size * sizeof(float);
|
||||
|
||||
sampling.probs = (float *) (base + offset);
|
||||
offset += sampling.probs_size * sizeof(float);
|
||||
|
||||
sampling.sampled = (llama_token *) (base + offset);
|
||||
offset += sampling.sampled_size * sizeof(llama_token);
|
||||
|
||||
sampling.candidates = (llama_token *) (base + offset);
|
||||
offset += sampling.candidates_size * sizeof(llama_token);
|
||||
|
||||
// The count vectors keep track of the actual number of logits/probs/candidates
|
||||
// copied from the backend for each output row.
|
||||
const size_t n_rows = (size_t) n_outputs_max;
|
||||
if (sampling.outputs_capacity < n_rows) {
|
||||
// The output size has increased, so resize and reset the count vectors.
|
||||
sampling.outputs_capacity = n_rows;
|
||||
|
||||
sampling.logits_count.assign(n_rows, 0);
|
||||
sampling.probs_count.assign(n_rows, 0);
|
||||
sampling.candidates_count.assign(n_rows, 0);
|
||||
} else {
|
||||
// The output size has not increased so just reset the counts to zero.
|
||||
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
|
||||
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
||||
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
||||
}
|
||||
|
||||
if (sampling.sampled && sampling.sampled_size > 0) {
|
||||
if (sampling.sampled) {
|
||||
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
|
||||
}
|
||||
}
|
||||
|
|
@ -1814,9 +1815,11 @@ void llama_context::output_reorder() {
|
|||
if (!sampling.logits_count.empty()) {
|
||||
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.probs_count.empty()) {
|
||||
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.candidates_count.empty()) {
|
||||
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue