sampling : use pinned memory for backend sampling buffers

This commit is contained in:
Daniel Bevenius 2025-11-21 14:02:16 +01:00
parent c1625620f6
commit 61ffe41dc1
No known key found for this signature in database
7 changed files with 358 additions and 120 deletions

4
common/build-info.cpp Normal file
View File

@ -0,0 +1,4 @@
int LLAMA_BUILD_NUMBER = 5590;
char const *LLAMA_COMMIT = "0d398442";
char const *LLAMA_COMPILER = "cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0";
char const *LLAMA_BUILD_TARGET = "x86_64-linux-gnu";

View File

@ -113,9 +113,9 @@ struct common_sampler {
llama_token_data_array cur_p;
void set_logits(struct llama_context * ctx, int idx) {
const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);
const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);

View File

@ -974,20 +974,23 @@ extern "C" {
// The index matches llama_get_backend_sampled_token_ith().
// Returns NULL if no probabilites were generated.
LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i);
//
// Get the number of backend sampled probabilites for the ith token.
LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
// Get the backend sampled logits for the ith token
// Returns NULL if no logits were sampled.
LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i);
// Get the backend sampled token ids associated with the sampled logits for the ith token
// Returns NULL if no logits were sampled.
LLAMA_API llama_token * llama_get_backend_sampled_token_ids_ith(struct llama_context * ctx, int32_t i);
//
// Get the number of backend sampled logits for the ith token.
LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
// Get the number of backend sampled probabilites for the ith token.
LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
// Get the backend sampled candidates (token ids) for the ith token
// Returns NULL if no candidates were sampled.
LLAMA_API llama_token * llama_get_backend_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
//
// Get the number of backend sampled candidates for the ith token.
LLAMA_API uint32_t llama_get_backend_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
//
// Vocab

View File

@ -588,6 +588,35 @@ float * llama_context::get_logits() {
return logits;
}
int64_t llama_context::resolve_output_row(int32_t i) const {
int64_t j = -1;
// support negative indices (last output row)
if (i < 0) {
j = n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
}
} else if ((size_t) i >= output_ids.size()) {
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
} else {
// use output_ids to translate the batch token index into a row number
// that holds this token's data.
j = output_ids[i];
}
if (j < 0) {
// the batch token was not configured to output anything
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= n_outputs) {
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return j;
}
float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;
@ -688,100 +717,135 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
}
llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
// Handle special case where idx == -1 (single sequence exists) which is
// a valid index when using common_sampler_sample.
if (idx == -1) {
if (sampling.map_sampled.size() == 1) {
auto it = sampling.map_sampled.begin();
return it->second;
}
output_reorder();
if (sampling.sampled == nullptr) {
return LLAMA_TOKEN_NULL;
}
auto it = sampling.map_sampled.find(idx);
if (it == sampling.map_sampled.end()) {
try {
const int64_t row = resolve_output_row(idx);
GGML_ASSERT(row < (int64_t) sampling.sampled_size);
return sampling.sampled[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
return LLAMA_TOKEN_NULL;
}
return it->second;
}
float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
if (idx == -1) {
if (sampling.map_probs.size() == 1) {
return sampling.map_probs.begin()->second.data();
}
}
output_reorder();
auto it = sampling.map_probs.find(idx);
if (it == sampling.map_probs.end()) {
if (sampling.probs == nullptr) {
return nullptr;
}
return it->second.data();
try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
return nullptr;
}
return sampling.probs + row*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
return nullptr;
}
}
float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
if (idx == -1) {
if (sampling.map_logits.size() == 1) {
return sampling.map_logits.begin()->second.data();
}
}
auto it = sampling.map_logits.find(idx);
if (it == sampling.map_logits.end()) {
output_reorder();
if (sampling.logits == nullptr) {
return nullptr;
}
return it->second.data();
try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
return nullptr;
}
return sampling.logits + row*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
return nullptr;
}
}
const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
if (idx == -1) {
if (sampling.map_cadidates.size() == 1) {
const auto & vec = sampling.map_cadidates.begin()->second;
if (!vec.empty()) {
return vec.data();
}
const llama_token * llama_context::get_backend_sampled_candidates_ith(int32_t idx) {
output_reorder();
try {
const int64_t row = resolve_output_row(idx);
if (sampling.candidates != nullptr &&
(size_t) row < sampling.candidates_count.size() &&
sampling.candidates_count[row] > 0) {
return sampling.candidates + row*model.vocab.n_tokens();
}
}
auto it = sampling.map_cadidates.find(idx);
if (it != sampling.map_cadidates.end() && !it->second.empty()) {
return it->second.data();
} catch (const std::exception & err) {
// fallback to full vocab list
}
return sampling.token_ids_full_vocab.data();
}
size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
if (idx == -1) {
if (sampling.map_logits.size() == 1) {
return sampling.map_logits.begin()->second.size();
}
}
auto it = sampling.map_logits.find(idx);
if (it == sampling.map_logits.end()) {
size_t llama_context::get_backend_sampled_candidates_count(int32_t idx) {
output_reorder();
if (sampling.candidates == nullptr) {
return 0;
}
return it->second.size();
try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.candidates_count.size()) {
return 0;
}
return sampling.candidates_count[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
return 0;
}
}
size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
if (idx == -1) {
if (sampling.map_probs.size() == 1) {
return sampling.map_probs.begin()->second.size();
size_t llama_context::get_backend_sampled_logits_count(int32_t idx) {
output_reorder();
if (sampling.logits == nullptr) {
return 0;
}
try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.logits_count.size()) {
return 0;
}
return sampling.logits_count[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
return 0;
}
auto it = sampling.map_probs.find(idx);
if (it == sampling.map_probs.end()) {
return 0;
}
return it->second.size();
}
size_t llama_context::get_backend_sampled_probs_count(int32_t idx) {
output_reorder();
if (sampling.probs == nullptr) {
return 0;
}
try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.probs_count.size()) {
return 0;
}
return sampling.probs_count[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
return 0;
}
}
void llama_context::attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) {
@ -1133,51 +1197,94 @@ int llama_context::encode(const llama_batch & batch_inp) {
return 0;
}
static std::unordered_map<llama_seq_id, int32_t> build_seq_to_batch_idx(const llama_ubatch & ubatch) {
std::unordered_map<llama_seq_id, int32_t> seq_to_batch_idx;
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
if (ubatch.output[i]) {
seq_to_batch_idx[ubatch.seq_id[i][0]] = i;
static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
std::unordered_map<llama_seq_id, uint32_t> seq_to_row;
// how many output tokens we have seen so far for this ubatch.
uint32_t local = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
// skip tokens that are not output.
if (!ubatch.output[i]) {
continue;
}
const llama_seq_id seq_id = ubatch.seq_id[i][0];
// row_offset is the number of output tokens before this ubatch.
seq_to_row[seq_id] = row_offset + local;
++local;
}
return seq_to_batch_idx;
return seq_to_row;
}
static void copy_tensor_async_int(
static void copy_tensor_async_ints(
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
std::unordered_map<int32_t, llama_token> & output_map,
const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
llama_token * sampled,
size_t sampled_size,
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
ggml_backend_sched_t sched) {
if (sampled == nullptr || sampled_size == 0) {
return;
}
for (const auto & [seq_id, tensor] : tensor_map) {
const int32_t idx = seq_to_batch_idx.at(seq_id);
auto it = seq_to_row.find(seq_id);
GGML_ASSERT(it != seq_to_row.end());
const uint32_t row = it->second;
GGML_ASSERT(row < sampled_size);
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
ggml_backend_tensor_get_async(backend, tensor, &output_map[idx], 0, sizeof(output_map[idx]));
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
}
}
static void copy_tensor_async_floats(
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
std::unordered_map<int32_t, std::vector<float>> & output_map,
const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
float * dst,
size_t stride,
std::vector<uint32_t> & counts,
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
ggml_backend_sched_t sched) {
if (dst == nullptr || stride == 0) {
return;
}
for (const auto & [seq_id, tensor] : tensor_map) {
const int32_t idx = seq_to_batch_idx.at(seq_id);
auto it = seq_to_row.find(seq_id);
GGML_ASSERT(it != seq_to_row.end());
const uint32_t row = it->second;
GGML_ASSERT(row < counts.size());
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
output_map[idx].resize(ggml_nelements(tensor));
ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
float * row_ptr = dst + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
// Update the actual number of logits/probabilities that were written for this row.
counts[row] = ggml_nelements(tensor);
}
}
static void copy_tensor_async_token_ids(
static void copy_tensor_async_candidates(
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
std::unordered_map<int32_t, std::vector<llama_token>> & output_map,
const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
llama_token * dst,
size_t stride,
std::vector<uint32_t> & counts,
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
ggml_backend_sched_t sched) {
if (dst == nullptr || stride == 0) {
return;
}
for (const auto & [seq_id, tensor] : tensor_map) {
const int32_t idx = seq_to_batch_idx.at(seq_id);
auto it = seq_to_row.find(seq_id);
GGML_ASSERT(it != seq_to_row.end());
const uint32_t row = it->second;
GGML_ASSERT(row < counts.size());
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
output_map[idx].resize(ggml_nelements(tensor));
ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
llama_token * row_ptr = dst + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
// Update the actual number of candidates that were written.
counts[row] = ggml_nelements(tensor);
}
}
@ -1235,10 +1342,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
sampling.map_probs.clear();
sampling.map_logits.clear();
sampling.map_sampled.clear();
sampling.map_cadidates.clear();
output_swaps.clear();
bool did_optimize = false;
@ -1364,24 +1467,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
if (has_backend_samplers && backend_has_sampled) {
const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
// If a backend sampler has sampled a token we only want to copy the
// sampled tokens and avoid copying logits and probabilites.
if (!res->t_sampled.empty()) {
// async copy the sampled tokens from the backend to the host.
copy_tensor_async_int(res->t_sampled, sampling.map_sampled, seq_to_batch_idx, sched.get());
copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
} else {
// async copy the sampled logits/probs from the backend to the host.
copy_tensor_async_floats(res->t_sampled_logits, sampling.map_logits, seq_to_batch_idx, sched.get());
copy_tensor_async_floats(res->t_sampled_probs, sampling.map_probs, seq_to_batch_idx, sched.get());
copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, n_vocab, sampling.logits_count, seq_to_output_row, sched.get());
copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, n_vocab, sampling.probs_count, seq_to_output_row, sched.get());
}
// async copy the candidate token ids from the backend to the host.
// These are needed for:
// 1) Backend dist sampler to map indices to vocab token ids.
// 2) CPU samplers to associate candidate logits with their token ids.
copy_tensor_async_token_ids(res->t_candidates, sampling.map_cadidates, seq_to_batch_idx, sched.get());
copy_tensor_async_candidates(res->t_candidates, sampling.candidates, n_vocab, sampling.candidates_count, seq_to_output_row, sched.get());
}
@ -1471,7 +1574,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
n_outputs = n_outputs_all;
// set output mappings
if (n_outputs > 0 && !backend_has_sampled) {
if (n_outputs > 0) {
bool sorted_output = true;
auto & out_ids = balloc->get_out_ids();
@ -1546,8 +1649,31 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
has_embd = true;
}
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd*n_outputs_max : 0;
const bool backend_sampling = !sampling.samplers.empty();
size_t backend_float_count = 0;
size_t backend_token_count = 0;
if (!backend_sampling) {
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd*n_outputs_max : 0;
// reset backend sampling values.
sampling.logits_size = 0;
sampling.probs_size = 0;
sampling.sampled_size = 0;
sampling.candidates_size = 0;
} else {
logits_size = 0;
embd_size = 0;
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
sampling.candidates_size = n_vocab*n_outputs_max;
backend_float_count = sampling.logits_size + sampling.probs_size;
backend_token_count = sampling.sampled_size + sampling.candidates_size;
}
if (output_ids.empty()) {
// init, never resized afterwards
@ -1555,7 +1681,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
}
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size = (logits_size + embd_size) * sizeof(float);
const size_t new_size = (logits_size + embd_size + backend_float_count) * sizeof(float)
+ backend_token_count * sizeof(llama_token);
// alloc only when more than the current capacity is required
// TODO: also consider shrinking the buffer
@ -1585,13 +1712,57 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
}
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
llama_token * s_output_base = (llama_token *) ggml_backend_buffer_get_base(buf_output.get());
logits = has_logits ? output_base : nullptr;
embd = has_embd ? output_base + logits_size : nullptr;
logits = nullptr;
embd = nullptr;
sampling.sampled = !sampling.samplers.empty() ? s_output_base : nullptr;
sampling.probs = !sampling.samplers.empty() ? embd : nullptr;
// reset sampling pointers.
sampling.logits = nullptr;
sampling.probs = nullptr;
sampling.sampled = nullptr;
sampling.candidates = nullptr;
if (!backend_sampling) {
logits = has_logits ? output_base : nullptr;
embd = has_embd ? output_base + logits_size : nullptr;
} else {
size_t offset = 0;
uint8_t * base = (uint8_t *) output_base;
if (sampling.logits_size > 0) {
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);
}
if (sampling.probs_size > 0) {
sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);
}
if (sampling.sampled_size > 0) {
sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);
}
if (sampling.candidates_size > 0) {
sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);
}
const size_t n_rows = (size_t) n_outputs_max;
if (sampling.outputs_capacity < n_rows) {
sampling.outputs_capacity = n_rows;
sampling.logits_count.assign(n_rows, 0);
sampling.probs_count.assign(n_rows, 0);
sampling.candidates_count.assign(n_rows, 0);
} else {
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
}
if (sampling.sampled && sampling.sampled_size > 0) {
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
}
}
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
@ -1620,6 +1791,38 @@ void llama_context::output_reorder() {
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
}
}
if (sampling.logits && sampling.logits_size > 0) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
}
}
if (sampling.probs && sampling.probs_size > 0) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
}
}
if (sampling.candidates && sampling.candidates_size > 0) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
}
}
if (sampling.sampled && sampling.sampled_size > 0) {
std::swap(sampling.sampled[i0], sampling.sampled[i1]);
}
if (!sampling.logits_count.empty()) {
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
}
if (!sampling.probs_count.empty()) {
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
}
if (!sampling.candidates_count.empty()) {
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
}
}
output_swaps.clear();
@ -2776,10 +2979,16 @@ float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) {
return ctx->get_backend_sampled_logits_ith(i);
}
llama_token * llama_get_backend_sampled_token_ids_ith(llama_context * ctx, int32_t i) {
llama_token * llama_get_backend_sampled_candidates_ith(llama_context * ctx, int32_t i) {
ctx->synchronize();
return const_cast<llama_token *>(ctx->get_backend_sampled_token_ids_ith(i));
return const_cast<llama_token *>(ctx->get_backend_sampled_candidates_ith(i));
}
uint32_t llama_get_backend_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
ctx->synchronize();
return static_cast<uint32_t>(ctx->get_backend_sampled_candidates_count(i));
}
uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) {

View File

@ -70,11 +70,13 @@ struct llama_context {
llama_token get_backend_sampled_token_ith(int32_t idx);
float * get_backend_sampled_logits_ith(int32_t idx);
const llama_token * get_backend_sampled_token_ids_ith(int32_t idx);
size_t get_backend_sampled_logits_count(int32_t idx) const;
size_t get_backend_sampled_logits_count(int32_t idx);
float * get_backend_sampled_probs_ith(int32_t idx);
size_t get_backend_sampled_probs_count(int32_t idx) const;
size_t get_backend_sampled_probs_count(int32_t idx);
const llama_token * get_backend_sampled_candidates_ith(int32_t idx);
size_t get_backend_sampled_candidates_count(int32_t idx);
void attach_threadpool(
ggml_threadpool_t threadpool,
@ -201,6 +203,7 @@ private:
uint32_t output_reserve(int32_t n_outputs);
void output_reorder();
int64_t resolve_output_row(int32_t i) const;
//
// graph
@ -257,13 +260,22 @@ private:
struct sampling_info {
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
llama_token * sampled = nullptr;
float * probs = nullptr;
float * logits = nullptr;
size_t logits_size = 0;
std::unordered_map<int32_t, llama_token> map_sampled;
std::unordered_map<int32_t, std::vector<float>> map_probs;
std::unordered_map<int32_t, std::vector<float>> map_logits;
std::unordered_map<int32_t, std::vector<llama_token>> map_cadidates;
llama_token * sampled = nullptr;
size_t sampled_size = 0;
float * probs = nullptr;
size_t probs_size = 0;
llama_token * candidates = nullptr;
size_t candidates_size = 0;
size_t outputs_capacity = 0;
std::vector<uint32_t> logits_count;
std::vector<uint32_t> probs_count;
std::vector<uint32_t> candidates_count;
std::vector<llama_token> token_ids_full_vocab;
};

View File

@ -442,7 +442,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx);
const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx);
const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx);
const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);
const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
// If a backend sampler has already sampled a token, return it.
if (sampled_token != LLAMA_TOKEN_NULL) {

View File

@ -325,6 +325,13 @@ static void test_backend_top_k_sampling(const char * model_path) {
printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
}
llama_token * candidates = llama_get_backend_sampled_candidates_ith(test_ctx.ctx, batch_idx);
uint32_t n_candidates = llama_get_backend_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
for (size_t i = 0; i < n_candidates; ++i) {
printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
test_ctx.token_to_piece(candidates[i], false).c_str());
}
// Sample using CPU sampler for verification that it is possible to do hybrid
// sampling, first top_k on the backend and then dist on the CPU.
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@ -370,6 +377,9 @@ static void test_backend_temp_sampling(const char * model_path) {
int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx_0);
GGML_ASSERT(n_logits == test_ctx.n_vocab);
// Sample from sequence 0 using CPU sampler
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0);