From 766e7e38769072329f0efae75fe908ad1fabc0c3 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Thu, 5 Feb 2026 15:48:26 +0100
Subject: [PATCH] llama : refactor sampling_info to use buffer_view template

This commit updates the sampling_info struct in llama-context to use a
buffer_view template for the logits, probs, sampled tokens, and
candidates buffers.

The motivation for this is to simplify the code, improve type safety
and readability.
---
 src/llama-context.cpp | 100 ++++++++++++++++++++----------
 src/llama-context.h   |  26 ++++++-----
 2 files changed, 61 insertions(+), 65 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 95b207e9e1..8748fb77b0 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -757,7 +757,7 @@ float * llama_context::get_embeddings() {
 }
 
 llama_token * llama_context::get_sampled_tokens() const{
-    return sampling.sampled;
+    return sampling.sampled.data;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
@@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.sampled == nullptr) {
+    if (!sampling.sampled.has_data()) {
         return LLAMA_TOKEN_NULL;
     }
 
     try {
         const int64_t row = output_resolve_row(idx);
-        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
-        return sampling.sampled[row];
+        GGML_ASSERT(row < (int64_t) sampling.sampled.size);
+        return sampling.sampled.data[row];
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
         return LLAMA_TOKEN_NULL;
@@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
 float * llama_context::get_sampled_probs_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return nullptr;
     }
 
@@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
         if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
             return nullptr;
         }
-        return sampling.probs + row*model.vocab.n_tokens();
+        return sampling.probs.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
 float * llama_context::get_sampled_logits_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return nullptr;
     }
 
@@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
         if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
             return nullptr;
         }
-        return sampling.logits + row*model.vocab.n_tokens();
+        return sampling.logits.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 
     try {
         const int64_t row = output_resolve_row(idx);
-        if (sampling.candidates != nullptr &&
+        if (sampling.candidates.has_data() &&
             (size_t) row < sampling.candidates_count.size() &&
             sampling.candidates_count[row] > 0) {
-            return sampling.candidates + row*model.vocab.n_tokens();
+            return sampling.candidates.data + row*model.vocab.n_tokens();
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
@@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 size_t llama_context::get_sampled_candidates_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.candidates == nullptr) {
+    if (!sampling.candidates.has_data()) {
         return 0;
     }
 
@@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
 size_t llama_context::get_sampled_logits_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return model.vocab.n_tokens();
     }
 
@@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
 size_t llama_context::get_sampled_probs_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return 0;
     }
 
@@ -1363,11 +1363,10 @@ static std::map build_seq_to_output_row(const llama_ubat
 
 static void copy_tensor_async_ints(
         const std::map & tensor_map,
-        llama_token * sampled,
-        size_t sampled_size,
+        const llama_context::buffer_view<llama_token> & sampled,
         const std::map & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (sampled == nullptr) {
+    if (!sampled.has_data()) {
         return;
     }
 
@@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints(
         }
 
         const uint32_t row = it->second;
-        GGML_ASSERT(row < sampled_size);
+        GGML_ASSERT(row < sampled.size);
 
         GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+        ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
     }
 }
 
 static void copy_tensor_async_floats(
         const std::map & tensor_map,
-        float * dst,
+        const llama_context::buffer_view<float> & dst,
         size_t stride,
         std::vector & counts,
         const std::map & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
 
@@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        float * row_ptr = dst + (size_t) row * stride;
+        float * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of logits/probabilities that were written for this row.
@@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats(
 
 static void copy_tensor_async_candidates(
         const std::map & tensor_map,
-        llama_token * dst,
+        const llama_context::buffer_view<llama_token> & dst,
         size_t stride,
         std::vector & counts,
         const std::map & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
 
@@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        llama_token * row_ptr = dst + (size_t) row * stride;
+        llama_token * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of candidates that were written.
@@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto stride = n_vocab;
 
     // async copy the sampling data from the backend to the host
-    copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+    copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
     copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
     copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
 
@@ -1847,13 +1846,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     // Allocate backend sampling output buffers if there are backend samplers configured.
     const bool has_sampling = !sampling.samplers.empty();
     if (has_sampling) {
-        sampling.logits_size = n_vocab*n_outputs_max;
-        sampling.probs_size = n_vocab*n_outputs_max;
-        sampling.sampled_size = n_outputs_max;
-        sampling.candidates_size = n_vocab*n_outputs_max;
-
-        backend_float_count = sampling.logits_size + sampling.probs_size;
-        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+        backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
+        backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
     }
 
     if (output_ids.empty()) {
@@ -1910,23 +1904,23 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     embd = has_embd ? (float *) (base + offset) : nullptr;
     offset += embd_size * sizeof(float);
 
-    sampling.logits = nullptr;
-    sampling.probs = nullptr;
-    sampling.sampled = nullptr;
-    sampling.candidates = nullptr;
+    sampling.logits = {nullptr, 0};
+    sampling.probs = {nullptr, 0};
+    sampling.sampled = {nullptr, 0};
+    sampling.candidates = {nullptr, 0};
 
     if (has_sampling) {
-        sampling.logits = (float *) (base + offset);
-        offset += sampling.logits_size * sizeof(float);
+        sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.logits.size * sizeof(float);
 
-        sampling.probs = (float *) (base + offset);
-        offset += sampling.probs_size * sizeof(float);
+        sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.probs.size * sizeof(float);
 
-        sampling.sampled = (llama_token *) (base + offset);
-        offset += sampling.sampled_size * sizeof(llama_token);
+        sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
+        offset += sampling.sampled.size * sizeof(llama_token);
 
-        sampling.candidates = (llama_token *) (base + offset);
-        offset += sampling.candidates_size * sizeof(llama_token);
+        sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.candidates.size * sizeof(llama_token);
 
         // The count vectors keep track of the actual number of logits/probs/candidates
         // copied from the backend for each output row.
@@ -1939,7 +1933,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
 
-        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+        std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
     }
 
     // set all ids as invalid (negative)
@@ -1970,26 +1964,26 @@ void llama_context::output_reorder() {
             }
         }
 
-        if (sampling.logits && sampling.logits_size > 0) {
+        if (sampling.logits.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+                std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.probs && sampling.probs_size > 0) {
+        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+                std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.candidates && sampling.candidates_size > 0) {
+        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+                std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.sampled && sampling.sampled_size > 0) {
-            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        if (sampling.sampled.has_data()) {
+            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
         }
 
         if (!sampling.logits_count.empty()) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 8e71cdd1dc..978a109292 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -238,6 +238,16 @@ public:
 
     bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
 
+    template <typename T>
+    struct buffer_view {
+        T * data;
+        size_t size = 0;
+
+        bool has_data() const {
+            return data && size > 0;
+        }
+    };
+
 private:
     llm_graph_params graph_params(
             llm_graph_result * res,
@@ -277,21 +287,13 @@ private:
     size_t embd_size = 0; // capacity (of floats) for embeddings
     float * embd = nullptr;
 
-    // TODO: simplify
     struct sampling_info {
         std::map samplers;
 
-        float * logits = nullptr;
-        size_t logits_size = 0;
-
-        llama_token * sampled = nullptr;
-        size_t sampled_size = 0;
-
-        float * probs = nullptr;
-        size_t probs_size = 0;
-
-        llama_token * candidates = nullptr;
-        size_t candidates_size = 0;
+        struct buffer_view<float> logits = {nullptr, 0};
+        struct buffer_view<llama_token> sampled = {nullptr, 0};
+        struct buffer_view<float> probs = {nullptr, 0};
+        struct buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector logits_count;
         std::vector probs_count;
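
For readers skimming the patch, below is a minimal, self-contained sketch of the buffer_view pattern it introduces. This sketch is illustration only and not part of the patch: llama_token is a simplified stand-in alias, the buffer sizes are made up, and main() merely exercises the has_data() guard that replaces the previous pointer/size checks.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for llama.cpp's token type (illustration only).
using llama_token = int32_t;

// Same shape as the buffer_view added to src/llama-context.h:
// a non-owning pointer plus an element count, with a single validity check.
template <typename T>
struct buffer_view {
    T * data;
    size_t size = 0;

    bool has_data() const {
        return data && size > 0;
    }
};

int main() {
    // Backing storage, loosely analogous to the single host buffer that
    // output_reserve() carves up into logits/probs/sampled/candidates regions.
    std::vector<float>       float_buf(16);
    std::vector<llama_token> token_buf(9);

    // Before the refactor this state was four pointer/size member pairs;
    // with the template each buffer becomes one typed view.
    buffer_view<float>       logits  = {float_buf.data(), 8};
    buffer_view<llama_token> sampled = {token_buf.data(), 1};
    buffer_view<float>       unset   = {nullptr, 0};

    // has_data() replaces the "ptr != nullptr && size > 0" style checks.
    std::printf("logits:  %s (%zu floats)\n", logits.has_data()  ? "ok" : "empty", logits.size);
    std::printf("sampled: %s (%zu tokens)\n", sampled.has_data() ? "ok" : "empty", sampled.size);
    std::printf("unset:   %s\n",              unset.has_data()   ? "ok" : "empty");

    return 0;
}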