From 61ffe41dc1cdafb2c71b650c4265c22ec769f88b Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Fri, 21 Nov 2025 14:02:16 +0100
Subject: [PATCH] sampling : use pinned memory for backend sampling buffers

---
 common/build-info.cpp          |   4 +
 common/sampling.cpp            |   6 +-
 include/llama.h                |  17 +-
 src/llama-context.cpp          | 409 +++++++++++++++++++++++++--------
 src/llama-context.h            |  30 ++-
 src/llama-sampling.cpp         |   2 +-
 tests/test-backend-sampler.cpp |  10 +
 7 files changed, 358 insertions(+), 120 deletions(-)
 create mode 100644 common/build-info.cpp

diff --git a/common/build-info.cpp b/common/build-info.cpp
new file mode 100644
index 0000000000..6e8240fbb1
--- /dev/null
+++ b/common/build-info.cpp
@@ -0,0 +1,4 @@
+int LLAMA_BUILD_NUMBER = 5590;
+char const *LLAMA_COMMIT = "0d398442";
+char const *LLAMA_COMPILER = "cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0";
+char const *LLAMA_BUILD_TARGET = "x86_64-linux-gnu";
diff --git a/common/sampling.cpp b/common/sampling.cpp
index ebe61f32ca..ae2276c712 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -113,9 +113,9 @@ struct common_sampler {
     llama_token_data_array cur_p;
 
     void set_logits(struct llama_context * ctx, int idx) {
-        const float       * sampled_probs  = llama_get_backend_sampled_probs_ith    (ctx, idx);
-        const float       * sampled_logits = llama_get_backend_sampled_logits_ith   (ctx, idx);
-        const llama_token * sampled_ids    = llama_get_backend_sampled_token_ids_ith(ctx, idx);
+        const float       * sampled_probs  = llama_get_backend_sampled_probs_ith     (ctx, idx);
+        const float       * sampled_logits = llama_get_backend_sampled_logits_ith    (ctx, idx);
+        const llama_token * sampled_ids    = llama_get_backend_sampled_candidates_ith(ctx, idx);
 
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
diff --git a/include/llama.h b/include/llama.h
index 9c4862ad89..9fbce771d7 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -974,20 +974,23 @@ extern "C" {
     // The index matches llama_get_backend_sampled_token_ith().
    // Returns NULL if no probabilites were generated.
     LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i);
+    //
+    // Get the number of backend sampled probabilities for the ith token.
+    LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
 
     // Get the backend sampled logits for the ith token
     // Returns NULL if no logits were sampled.
     LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i);
-
-    // Get the backend sampled token ids associated with the sampled logits for the ith token
-    // Returns NULL if no logits were sampled.
-    LLAMA_API llama_token * llama_get_backend_sampled_token_ids_ith(struct llama_context * ctx, int32_t i);
-
+    //
     // Get the number of backend sampled logits for the ith token.
     LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
 
-    // Get the number of backend sampled probabilites for the ith token.
-    LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+    // Get the backend sampled candidates (token ids) for the ith token
+    // Returns NULL if no candidates were sampled.
+    LLAMA_API llama_token * llama_get_backend_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
+    //
+    // Get the number of backend sampled candidates for the ith token.
+    LLAMA_API uint32_t llama_get_backend_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
 
     //
     // Vocab
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 7bebf58b9e..1694e44720 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -588,6 +588,35 @@ float * llama_context::get_logits() {
     return logits;
 }
 
+int64_t llama_context::resolve_output_row(int32_t i) const {
+    int64_t j = -1;
+
+    // support negative indices (last output row)
+    if (i < 0) {
+        j = n_outputs + i;
+        if (j < 0) {
+            throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
+        }
+    } else if ((size_t) i >= output_ids.size()) {
+        throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+    } else {
+        // use output_ids to translate the batch token index into a row number
+        // that holds this token's data.
+        j = output_ids[i];
+    }
+
+    if (j < 0) {
+        // the batch token was not configured to output anything
+        throw std::runtime_error(format("batch.logits[%d] != true", i));
+    }
+
+    if (j >= n_outputs) {
+        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
+    }
+
+    return j;
+}
+
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
@@ -688,100 +717,135 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 }
 
 llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
-    // Handle special case where idx == -1 (single sequence exists) which is
-    // a valid index when using common_sampler_sample.
-    if (idx == -1) {
-        if (sampling.map_sampled.size() == 1) {
-            auto it = sampling.map_sampled.begin();
-            return it->second;
-        }
+    output_reorder();
+
+    if (sampling.sampled == nullptr) {
         return LLAMA_TOKEN_NULL;
     }
 
-    auto it = sampling.map_sampled.find(idx);
-    if (it == sampling.map_sampled.end()) {
+    try {
+        const int64_t row = resolve_output_row(idx);
+        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
+        return sampling.sampled[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
         return LLAMA_TOKEN_NULL;
     }
-
-    return it->second;
 }
 
 float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
-    if (idx == -1) {
-        if (sampling.map_probs.size() == 1) {
-            return sampling.map_probs.begin()->second.data();
-        }
-    }
+    output_reorder();
 
-    auto it = sampling.map_probs.find(idx);
-    if (it == sampling.map_probs.end()) {
+    if (sampling.probs == nullptr) {
         return nullptr;
     }
 
-    return it->second.data();
+    try {
+        const int64_t row = resolve_output_row(idx);
+        if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.probs + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
 }
 
 float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
-    if (idx == -1) {
-        if (sampling.map_logits.size() == 1) {
-            return sampling.map_logits.begin()->second.data();
-        }
-    }
-    auto it = sampling.map_logits.find(idx);
-    if (it == sampling.map_logits.end()) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
        return nullptr;
     }
 
-    return it->second.data();
+    try {
+        const int64_t row = resolve_output_row(idx);
+        if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.logits + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
 }
 
-const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
-    if (idx == -1) {
-        if (sampling.map_cadidates.size() == 1) {
-            const auto & vec = sampling.map_cadidates.begin()->second;
-            if (!vec.empty()) {
-                return vec.data();
-            }
+const llama_token * llama_context::get_backend_sampled_candidates_ith(int32_t idx) {
+    output_reorder();
+
+    try {
+        const int64_t row = resolve_output_row(idx);
+        if (sampling.candidates != nullptr &&
+            (size_t) row < sampling.candidates_count.size() &&
+            sampling.candidates_count[row] > 0) {
+            return sampling.candidates + row*model.vocab.n_tokens();
         }
-    }
-    auto it = sampling.map_cadidates.find(idx);
-    if (it != sampling.map_cadidates.end() && !it->second.empty()) {
-        return it->second.data();
+    } catch (const std::exception & err) {
+        // fallback to full vocab list
     }
 
     return sampling.token_ids_full_vocab.data();
 }
 
-size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
-    if (idx == -1) {
-        if (sampling.map_logits.size() == 1) {
-            return sampling.map_logits.begin()->second.size();
-        }
-    }
-    auto it = sampling.map_logits.find(idx);
-    if (it == sampling.map_logits.end()) {
+size_t llama_context::get_backend_sampled_candidates_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.candidates == nullptr) {
         return 0;
     }
-    return it->second.size();
+
+    try {
+        const int64_t row = resolve_output_row(idx);
+        if ((size_t) row >= sampling.candidates_count.size()) {
+            return 0;
+        }
+        return sampling.candidates_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
 }
 
-size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
-    if (idx == -1) {
-        if (sampling.map_probs.size() == 1) {
-            return sampling.map_probs.begin()->second.size();
+size_t llama_context::get_backend_sampled_logits_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = resolve_output_row(idx);
+        if ((size_t) row >= sampling.logits_count.size()) {
+            return 0;
        }
+        return sampling.logits_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
         return 0;
     }
-
-    auto it = sampling.map_probs.find(idx);
-    if (it == sampling.map_probs.end()) {
-        return 0;
-    }
-
-    return it->second.size();
 }
 
+size_t llama_context::get_backend_sampled_probs_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = resolve_output_row(idx);
+        if ((size_t) row >= sampling.probs_count.size()) {
+            return 0;
+        }
+        return sampling.probs_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
        return 0;
     }
+}
+
+
 void llama_context::attach_threadpool(
     ggml_threadpool_t threadpool,
     ggml_threadpool_t threadpool_batch) {
@@ -1133,51 +1197,94 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
-static std::unordered_map<llama_seq_id, int32_t> build_seq_to_batch_idx(const llama_ubatch & ubatch) {
-    std::unordered_map<llama_seq_id, int32_t> seq_to_batch_idx;
-    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-        if (ubatch.output[i]) {
-            seq_to_batch_idx[ubatch.seq_id[i][0]] = i;
+static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
+    std::unordered_map<llama_seq_id, uint32_t> seq_to_row;
+    // how many output tokens we have seen so far for this ubatch.
+    uint32_t local = 0;
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        // skip tokens that are not output.
+        if (!ubatch.output[i]) {
+            continue;
         }
+
+        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+        // row_offset is the number of output tokens before this ubatch.
+        seq_to_row[seq_id] = row_offset + local;
+        ++local;
     }
-    return seq_to_batch_idx;
+    return seq_to_row;
 }
 
-static void copy_tensor_async_int(
+static void copy_tensor_async_ints(
        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
-        std::unordered_map<int32_t, llama_token> & output_map,
-        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        llama_token * sampled,
+        size_t sampled_size,
+        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
+    if (sampled == nullptr || sampled_size == 0) {
+        return;
+    }
+
     for (const auto & [seq_id, tensor] : tensor_map) {
-        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        auto it = seq_to_row.find(seq_id);
+        GGML_ASSERT(it != seq_to_row.end());
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < sampled_size);
+
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        ggml_backend_tensor_get_async(backend, tensor, &output_map[idx], 0, sizeof(output_map[idx]));
+        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
     }
 }
 
 static void copy_tensor_async_floats(
         const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
-        std::unordered_map<int32_t, std::vector<float>> & output_map,
-        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        float * dst,
+        size_t stride,
+        std::vector<size_t> & counts,
+        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
+    if (dst == nullptr || stride == 0) {
+        return;
+    }
+
     for (const auto & [seq_id, tensor] : tensor_map) {
-        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        auto it = seq_to_row.find(seq_id);
+        GGML_ASSERT(it != seq_to_row.end());
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        output_map[idx].resize(ggml_nelements(tensor));
-        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+        float * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of logits/probabilities that were written for this row.
+        counts[row] = ggml_nelements(tensor);
     }
 }
 
-static void copy_tensor_async_token_ids(
+static void copy_tensor_async_candidates(
         const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
-        std::unordered_map<int32_t, std::vector<llama_token>> & output_map,
-        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        llama_token * dst,
+        size_t stride,
+        std::vector<size_t> & counts,
+        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
        ggml_backend_sched_t sched) {
+    if (dst == nullptr || stride == 0) {
+        return;
+    }
+
     for (const auto & [seq_id, tensor] : tensor_map) {
-        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        auto it = seq_to_row.find(seq_id);
+        GGML_ASSERT(it != seq_to_row.end());
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        output_map[idx].resize(ggml_nelements(tensor));
-        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+        llama_token * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of candidates that were written.
+        counts[row] = ggml_nelements(tensor);
     }
 }
 
@@ -1235,10 +1342,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
-    sampling.map_probs.clear();
-    sampling.map_logits.clear();
-    sampling.map_sampled.clear();
-    sampling.map_cadidates.clear();
     output_swaps.clear();
 
     bool did_optimize = false;
@@ -1364,24 +1467,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
        backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
 
         if (has_backend_samplers && backend_has_sampled) {
-            const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
+            const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
 
             // If a backend sampler has sampled a token we only want to copy the
             // sampled tokens and avoid copying logits and probabilites.
             if (!res->t_sampled.empty()) {
                 // async copy the sampled tokens from the backend to the host.
-                copy_tensor_async_int(res->t_sampled, sampling.map_sampled, seq_to_batch_idx, sched.get());
+                copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
             } else {
                 // async copy the sampled logits/probs from the backend to the host.
-                copy_tensor_async_floats(res->t_sampled_logits, sampling.map_logits, seq_to_batch_idx, sched.get());
-                copy_tensor_async_floats(res->t_sampled_probs,  sampling.map_probs,  seq_to_batch_idx, sched.get());
+                copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, n_vocab, sampling.logits_count, seq_to_output_row, sched.get());
+                copy_tensor_async_floats(res->t_sampled_probs,  sampling.probs,  n_vocab, sampling.probs_count,  seq_to_output_row, sched.get());
             }
 
             // async copy the candidate token ids from the backend to the host.
             // These are needed for:
             // 1) Backend dist sampler to map indices to vocab token ids.
             // 2) CPU samplers to associate candidate logits with their token ids.
-            copy_tensor_async_token_ids(res->t_candidates, sampling.map_cadidates, seq_to_batch_idx, sched.get());
+            copy_tensor_async_candidates(res->t_candidates, sampling.candidates, n_vocab, sampling.candidates_count, seq_to_output_row, sched.get());
         }
 
 
@@ -1471,7 +1574,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     n_outputs = n_outputs_all;
 
     // set output mappings
-    if (n_outputs > 0 && !backend_has_sampled) {
+    if (n_outputs > 0) {
         bool sorted_output = true;
 
         auto & out_ids = balloc->get_out_ids();
@@ -1546,8 +1649,31 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         has_embd = true;
     }
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+    const bool backend_sampling = !sampling.samplers.empty();
+    size_t backend_float_count = 0;
+    size_t backend_token_count = 0;
+
+    if (!backend_sampling) {
+        logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+        embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+
+        // reset backend sampling values.
+        sampling.logits_size     = 0;
+        sampling.probs_size      = 0;
+        sampling.sampled_size    = 0;
+        sampling.candidates_size = 0;
+    } else {
+        logits_size = 0;
+        embd_size   = 0;
+
+        sampling.logits_size     = n_vocab*n_outputs_max;
+        sampling.probs_size      = n_vocab*n_outputs_max;
+        sampling.sampled_size    = n_outputs_max;
+        sampling.candidates_size = n_vocab*n_outputs_max;
+
+        backend_float_count = sampling.logits_size + sampling.probs_size;
+        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+    }
 
     if (output_ids.empty()) {
         // init, never resized afterwards
@@ -1555,7 +1681,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     }
 
     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
-    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+    const size_t new_size  = (logits_size + embd_size + backend_float_count) * sizeof(float) +
+                             backend_token_count * sizeof(llama_token);
 
     // alloc only when more than the current capacity is required
     // TODO: also consider shrinking the buffer
@@ -1585,13 +1712,57 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     }
 
     float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
-    llama_token * s_output_base = (llama_token *) ggml_backend_buffer_get_base(buf_output.get());
 
-    logits = has_logits ? output_base : nullptr;
-    embd   = has_embd   ? output_base + logits_size : nullptr;
+    logits = nullptr;
+    embd   = nullptr;
 
-    sampling.sampled = !sampling.samplers.empty() ? s_output_base : nullptr;
-    sampling.probs   = !sampling.samplers.empty() ? embd          : nullptr;
+    // reset sampling pointers.
+    sampling.logits     = nullptr;
+    sampling.probs      = nullptr;
+    sampling.sampled    = nullptr;
+    sampling.candidates = nullptr;
+
+    if (!backend_sampling) {
+        logits = has_logits ? output_base : nullptr;
+        embd   = has_embd   ? output_base + logits_size : nullptr;
+    } else {
+        size_t offset = 0;
+        uint8_t * base = (uint8_t *) output_base;
+
+        if (sampling.logits_size > 0) {
+            sampling.logits = (float *) (base + offset);
+            offset += sampling.logits_size * sizeof(float);
+        }
+        if (sampling.probs_size > 0) {
+            sampling.probs = (float *) (base + offset);
+            offset += sampling.probs_size * sizeof(float);
+        }
+        if (sampling.sampled_size > 0) {
+            sampling.sampled = (llama_token *) (base + offset);
+            offset += sampling.sampled_size * sizeof(llama_token);
+        }
+        if (sampling.candidates_size > 0) {
+            sampling.candidates = (llama_token *) (base + offset);
+            offset += sampling.candidates_size * sizeof(llama_token);
+        }
+
+        const size_t n_rows = (size_t) n_outputs_max;
+        if (sampling.outputs_capacity < n_rows) {
+            sampling.outputs_capacity = n_rows;
+
+            sampling.logits_count.assign(n_rows, 0);
+            sampling.probs_count.assign(n_rows, 0);
+            sampling.candidates_count.assign(n_rows, 0);
+        } else {
+            std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
+            std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
+            std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
+        }
+
+        if (sampling.sampled && sampling.sampled_size > 0) {
+            std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+        }
+    }
 
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1620,6 +1791,38 @@ void llama_context::output_reorder() {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }
+
+        if (sampling.logits && sampling.logits_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.probs && sampling.probs_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.candidates && sampling.candidates_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.sampled && sampling.sampled_size > 0) {
+            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        }
+
+        if (!sampling.logits_count.empty()) {
+            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+        }
+        if (!sampling.probs_count.empty()) {
+            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
+        }
+        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
+        }
     }
 
     output_swaps.clear();
@@ -2776,10 +2979,16 @@ float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) {
     return ctx->get_backend_sampled_logits_ith(i);
 }
 
-llama_token * llama_get_backend_sampled_token_ids_ith(llama_context * ctx, int32_t i) {
+llama_token * llama_get_backend_sampled_candidates_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
 
-    return const_cast<llama_token *>(ctx->get_backend_sampled_token_ids_ith(i));
+    return const_cast<llama_token *>(ctx->get_backend_sampled_candidates_ith(i));
+}
+
+uint32_t llama_get_backend_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_backend_sampled_candidates_count(i));
 }
 
 uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 8e6a111e61..2bdbf8a553 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -70,11 +70,13 @@ struct llama_context {
     llama_token get_backend_sampled_token_ith(int32_t idx);
     float * get_backend_sampled_logits_ith(int32_t idx);
-    const llama_token * get_backend_sampled_token_ids_ith(int32_t idx);
-    size_t get_backend_sampled_logits_count(int32_t idx) const;
+    size_t get_backend_sampled_logits_count(int32_t idx);
 
     float * get_backend_sampled_probs_ith(int32_t idx);
-    size_t get_backend_sampled_probs_count(int32_t idx) const;
+    size_t get_backend_sampled_probs_count(int32_t idx);
+
+    const llama_token * get_backend_sampled_candidates_ith(int32_t idx);
+    size_t get_backend_sampled_candidates_count(int32_t idx);
 
     void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch);
@@ -201,6 +203,7 @@ private:
     uint32_t output_reserve(int32_t n_outputs);
 
     void output_reorder();
+    int64_t resolve_output_row(int32_t i) const;
 
     //
     // graph
@@ -257,13 +260,22 @@ private:
     struct sampling_info {
         std::unordered_map<llama_seq_id, llama_sampler *> samplers;
 
-        llama_token * sampled = nullptr;
-        float       * probs   = nullptr;
+        float * logits = nullptr;
+        size_t logits_size = 0;
 
-        std::unordered_map<int32_t, llama_token>              map_sampled;
-        std::unordered_map<int32_t, std::vector<float>>       map_probs;
-        std::unordered_map<int32_t, std::vector<float>>       map_logits;
-        std::unordered_map<int32_t, std::vector<llama_token>> map_cadidates;
+        llama_token * sampled = nullptr;
+        size_t sampled_size = 0;
+
+        float * probs = nullptr;
+        size_t probs_size = 0;
+
+        llama_token * candidates = nullptr;
+        size_t candidates_size = 0;
+
+        size_t outputs_capacity = 0;
+        std::vector<size_t> logits_count;
+        std::vector<size_t> probs_count;
+        std::vector<size_t> candidates_count;
 
         std::vector<llama_token> token_ids_full_vocab;
     };
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 11679c6c9e..c126b70226 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -442,7 +442,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx);
     const float * sampled_probs     = llama_get_backend_sampled_probs_ith(ctx, idx);
     const float * sampled_logits    = llama_get_backend_sampled_logits_ith(ctx, idx);
-    const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);
+    const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
 
     // If a backend sampler has already sampled a token, return it.
     if (sampled_token != LLAMA_TOKEN_NULL) {
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index c6d0d1a38d..2ed13688c9 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -325,6 +325,13 @@ static void test_backend_top_k_sampling(const char * model_path) {
         printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
     }
 
+    llama_token * candidates = llama_get_backend_sampled_candidates_ith(test_ctx.ctx, batch_idx);
+    uint32_t n_candidates = llama_get_backend_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
+    for (size_t i = 0; i < n_candidates; ++i) {
+        printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
+               test_ctx.token_to_piece(candidates[i], false).c_str());
+    }
+
     // Sample using CPU sampler for verification that it is possible to do hybrid
     // sampling, first top_k on the backend and then dist on the CPU.
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@@ -370,6 +377,9 @@ static void test_backend_temp_sampling(const char * model_path) {
     int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
     int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
 
+    int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx_0);
+    GGML_ASSERT(n_logits == test_ctx.n_vocab);
+
     // Sample from sequence 0 using CPU sampler
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
     struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0);
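
Usage sketch (illustrative only, not part of the diff): the snippet below shows one way a client could consume the accessors added above after llama_decode(). The helper name read_backend_sample() and the greedy fallback are assumptions made for this example; the llama_get_backend_sampled_* functions are the ones declared in include/llama.h by this patch.

#include <stdint.h>
#include "llama.h"

// Sketch: consume the backend sampling output for batch output index `idx`
// after llama_decode(). Falls back to a greedy CPU pick over the copied
// candidates when no token was sampled directly on the backend.
static llama_token read_backend_sample(struct llama_context * ctx, int32_t idx) {
    // 1) a backend sampler may already have selected the token
    const llama_token tok = llama_get_backend_sampled_token_ith(ctx, idx);
    if (tok != LLAMA_TOKEN_NULL) {
        return tok;
    }

    // 2) otherwise the candidate ids and their logits (or probs) were copied
    //    row-wise from the backend into the pinned output buffer
    const llama_token * ids    = llama_get_backend_sampled_candidates_ith(ctx, idx);
    const float       * logits = llama_get_backend_sampled_logits_ith(ctx, idx);
    const float       * probs  = llama_get_backend_sampled_probs_ith(ctx, idx);

    const float  * scores = logits != nullptr ? logits : probs;
    const uint32_t n      = logits != nullptr
        ? llama_get_backend_sampled_logits_count_ith(ctx, idx)
        : llama_get_backend_sampled_probs_count_ith(ctx, idx);

    if (ids == nullptr || scores == nullptr || n == 0) {
        return LLAMA_TOKEN_NULL; // the backend samplers produced nothing for this row
    }

    // greedy pick purely for illustration; a real client would feed ids/scores
    // into a llama_sampler chain instead (see common/sampling.cpp)
    uint32_t best = 0;
    for (uint32_t i = 1; i < n; ++i) {
        if (scores[i] > scores[best]) {
            best = i;
        }
    }
    return ids[best];
}

In practice the fallback step hands the candidate ids and their scores to a CPU sampler chain, which is exactly the hybrid flow exercised by tests/test-backend-sampler.cpp.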
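For reference, when backend samplers are configured output_reserve() places all four regions in a single pinned host buffer. The sketch below only restates that sizing and partitioning arithmetic with stand-in values; n_vocab and n_outputs_max are assumed numbers, and llama_token is int32_t as in llama.h.

#include <cstdio>
#include <cstdint>
#include <cstddef>

int main() {
    using llama_token = int32_t;            // matches the typedef in llama.h

    const size_t n_vocab       = 32000;     // assumed vocabulary size
    const size_t n_outputs_max = 8;         // assumed maximum number of output rows

    // region sizes in elements, mirroring output_reserve() when backend sampling is enabled
    const size_t logits_size     = n_vocab * n_outputs_max;
    const size_t probs_size      = n_vocab * n_outputs_max;
    const size_t sampled_size    =           n_outputs_max;
    const size_t candidates_size = n_vocab * n_outputs_max;

    // a single pinned allocation laid out as [logits | probs | sampled | candidates]
    const size_t new_size = (logits_size + probs_size) * sizeof(float)
                          + (sampled_size + candidates_size) * sizeof(llama_token);

    // row r of the logits/probs/candidates regions starts at r * n_vocab elements,
    // while the sampled region holds one token per row; the get_backend_sampled_*_ith()
    // accessors return pointers computed from exactly these offsets
    std::printf("pinned output buffer: %zu bytes for %zu output rows\n", new_size, n_outputs_max);
    return 0;
}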