sampling : cleanup and clarify output_reserve
This commit is contained in:
parent d88ba1813c
commit 4a90583d7d
|
|
@ -1723,40 +1723,41 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||||
logits = has_logits ? output_base : nullptr;
|
logits = has_logits ? output_base : nullptr;
|
||||||
embd = has_embd ? output_base + logits_size : nullptr;
|
embd = has_embd ? output_base + logits_size : nullptr;
|
||||||
} else {
|
} else {
|
||||||
|
// Allocate worst case (full vocabulary size) for backend sampled
|
||||||
|
// data in the pinned memory buffer.
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
uint8_t * base = (uint8_t *) output_base;
|
uint8_t * base = (uint8_t *) output_base;
|
||||||
|
|
||||||
if (sampling.logits_size > 0) {
|
sampling.logits = (float *) (base + offset);
|
||||||
sampling.logits = (float *) (base + offset);
|
offset += sampling.logits_size * sizeof(float);
|
||||||
offset += sampling.logits_size * sizeof(float);
|
|
||||||
}
|
|
||||||
if (sampling.probs_size > 0) {
|
|
||||||
sampling.probs = (float *) (base + offset);
|
|
||||||
offset += sampling.probs_size * sizeof(float);
|
|
||||||
}
|
|
||||||
if (sampling.sampled_size > 0) {
|
|
||||||
sampling.sampled = (llama_token *) (base + offset);
|
|
||||||
offset += sampling.sampled_size * sizeof(llama_token);
|
|
||||||
}
|
|
||||||
if (sampling.candidates_size > 0) {
|
|
||||||
sampling.candidates = (llama_token *) (base + offset);
|
|
||||||
offset += sampling.candidates_size * sizeof(llama_token);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
sampling.probs = (float *) (base + offset);
|
||||||
|
offset += sampling.probs_size * sizeof(float);
|
||||||
|
|
||||||
|
sampling.sampled = (llama_token *) (base + offset);
|
||||||
|
offset += sampling.sampled_size * sizeof(llama_token);
|
||||||
|
|
||||||
|
sampling.candidates = (llama_token *) (base + offset);
|
||||||
|
offset += sampling.candidates_size * sizeof(llama_token);
|
||||||
|
|
||||||
|
// The count vectors keep track of the actual number of logits/probs/candidates
|
||||||
|
// copied from the backend for each output row.
|
||||||
const size_t n_rows = (size_t) n_outputs_max;
|
const size_t n_rows = (size_t) n_outputs_max;
|
||||||
if (sampling.outputs_capacity < n_rows) {
|
if (sampling.outputs_capacity < n_rows) {
|
||||||
|
// The output size has increased, so resize and reset the count vectors.
|
||||||
sampling.outputs_capacity = n_rows;
|
sampling.outputs_capacity = n_rows;
|
||||||
|
|
||||||
sampling.logits_count.assign(n_rows, 0);
|
sampling.logits_count.assign(n_rows, 0);
|
||||||
sampling.probs_count.assign(n_rows, 0);
|
sampling.probs_count.assign(n_rows, 0);
|
||||||
sampling.candidates_count.assign(n_rows, 0);
|
sampling.candidates_count.assign(n_rows, 0);
|
||||||
} else {
|
} else {
|
||||||
|
// The output size has not increased so just reset the counts to zero.
|
||||||
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
|
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
|
||||||
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
||||||
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sampling.sampled && sampling.sampled_size > 0) {
|
if (sampling.sampled) {
|
||||||
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
|
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1814,9 +1815,11 @@ void llama_context::output_reorder() {
|
||||||
if (!sampling.logits_count.empty()) {
|
if (!sampling.logits_count.empty()) {
|
||||||
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!sampling.probs_count.empty()) {
|
if (!sampling.probs_count.empty()) {
|
||||||
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!sampling.candidates_count.empty()) {
|
if (!sampling.candidates_count.empty()) {
|
||||||
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue