llama : refactor sampling_info to use buffer_view template (#19368)

* llama : refactor sampling_info to use buffer_view template

This commit updates the sampling_info struct in llama-context to use a
buffer_view template for the logits, probs, sampled tokens, and
candidates buffers.

The motivation for this is to simplify the code, improve type safety,
and improve readability.
Author: Daniel Bevenius, 2026-02-11 05:38:13 +01:00 (committed by GitHub)
Parent: 612db61886
Commit: 2cce9fddb7
3 changed files with 107 additions and 115 deletions
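
For reference, the pattern this commit applies, condensed from the diff below. The buffer_view template (added in llama-impl.h, the last file below) bundles a pointer with its element count so the parallel pointer/size members can be collapsed. This sketch only restates what the hunks show, with the member list abbreviated:

    #include <cstddef>

    // buffer_view as introduced in llama-impl.h below.
    template <typename T>
    struct buffer_view {
        T * data;
        size_t size = 0;

        bool has_data() const {
            return data && size > 0;
        }
    };

    // Before: every output buffer was a raw pointer plus a separate size
    // member that had to be kept in sync by hand:
    //     float * logits     = nullptr;
    //     size_t logits_size = 0;
    //
    // After: one member carries both, and the null checks collapse:
    //     buffer_view<float> logits = {nullptr, 0};
    //     if (logits.has_data()) { /* logits.data[0 .. logits.size) is valid */ }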

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp

@@ -677,7 +677,7 @@ enum llama_pooling_type llama_context::pooling_type() const {
 float * llama_context::get_logits() {
     output_reorder();
 
-    return logits;
+    return logits.data;
 }
 
 int64_t llama_context::output_resolve_row(int32_t i) const {
@@ -715,7 +715,7 @@ float * llama_context::get_logits_ith(int32_t i) {
     output_reorder();
 
     try {
-        if (logits == nullptr) {
+        if (logits.data == nullptr) {
            throw std::runtime_error("no logits");
         }
@@ -739,7 +739,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
-        return logits + j*model.vocab.n_tokens();
+        return logits.data + j*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -753,11 +753,11 @@ float * llama_context::get_logits_ith(int32_t i) {
 float * llama_context::get_embeddings() {
     output_reorder();
 
-    return embd;
+    return embd.data;
 }
 
 llama_token * llama_context::get_sampled_tokens() const{
-    return sampling.sampled;
+    return sampling.sampled.data;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
@@ -766,7 +766,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
     output_reorder();
 
     try {
-        if (embd == nullptr) {
+        if (embd.data == nullptr) {
            throw std::runtime_error("no embeddings");
         }
@@ -791,7 +791,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
         }
 
         const uint32_t n_embd_out = model.hparams.n_embd_out();
-        return embd + j*n_embd_out;
+        return embd.data + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.sampled == nullptr) {
+    if (!sampling.sampled.has_data()) {
         return LLAMA_TOKEN_NULL;
     }
 
     try {
         const int64_t row = output_resolve_row(idx);
-        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
-        return sampling.sampled[row];
+        GGML_ASSERT(row < (int64_t) sampling.sampled.size);
+        return sampling.sampled.data[row];
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
         return LLAMA_TOKEN_NULL;
@@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
 float * llama_context::get_sampled_probs_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return nullptr;
     }
@@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
         if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
             return nullptr;
         }
-        return sampling.probs + row*model.vocab.n_tokens();
+        return sampling.probs.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
 float * llama_context::get_sampled_logits_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return nullptr;
     }
@@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
         if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
             return nullptr;
         }
-        return sampling.logits + row*model.vocab.n_tokens();
+        return sampling.logits.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
     try {
         const int64_t row = output_resolve_row(idx);
-        if (sampling.candidates != nullptr &&
+        if (sampling.candidates.has_data() &&
             (size_t) row < sampling.candidates_count.size() &&
             sampling.candidates_count[row] > 0) {
-            return sampling.candidates + row*model.vocab.n_tokens();
+            return sampling.candidates.data + row*model.vocab.n_tokens();
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
@@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 size_t llama_context::get_sampled_candidates_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.candidates == nullptr) {
+    if (!sampling.candidates.has_data()) {
         return 0;
     }
@@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
 size_t llama_context::get_sampled_logits_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return model.vocab.n_tokens();
     }
@@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
 size_t llama_context::get_sampled_probs_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return 0;
     }
@@ -1254,16 +1254,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
     // extract logits
-    if (logits && t_logits) {
+    if (logits.data && t_logits) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
-        GGML_ASSERT(logits != nullptr);
+        GGML_ASSERT(logits.data != nullptr);
 
-        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+        ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
     }
 
     // extract embeddings
-    if (embd && t_embd) {
+    if (embd.data && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
@@ -1271,11 +1271,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
-                    GGML_ASSERT(embd != nullptr);
+                    GGML_ASSERT(embd.data != nullptr);
                     const uint32_t n_embd_out = hparams.n_embd_out();
-                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
+                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
@@ -1323,7 +1323,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
         cross.n_embd = t_embd->ne[0];
         cross.n_enc = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
-        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
+        memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
 
         const auto & batch = balloc->get_batch();
@@ -1363,11 +1363,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
 static void copy_tensor_async_ints(
         const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-        llama_token * sampled,
-        size_t sampled_size,
+        const buffer_view<llama_token> & sampled,
         const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (sampled == nullptr) {
+    if (!sampled.has_data()) {
         return;
     }
@@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints(
         }
         const uint32_t row = it->second;
-        GGML_ASSERT(row < sampled_size);
+        GGML_ASSERT(row < sampled.size);
         GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+        ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
     }
 }
 
 static void copy_tensor_async_floats(
         const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-        float * dst,
+        const buffer_view<float> & dst,
         size_t stride,
         std::vector<uint32_t> & counts,
         const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
@@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        float * row_ptr = dst + (size_t) row * stride;
+        float * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of logits/probabilities that were written for this row.
@@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats(
 
 static void copy_tensor_async_candidates(
         const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-        llama_token * dst,
+        const buffer_view<llama_token> & dst,
         size_t stride,
         std::vector<uint32_t> & counts,
         const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
@@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        llama_token * row_ptr = dst + (size_t) row * stride;
+        llama_token * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of candidates that were written.
@@ -1671,22 +1670,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         // extract logits
-        if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
+        if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
             GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(logits != nullptr);
+            GGML_ASSERT(logits.data != nullptr);
 
-            float * logits_out = logits + n_outputs_prev*n_vocab;
+            float * logits_out = logits.data + n_outputs_prev*n_vocab;
 
             if (n_outputs) {
                 GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
                 ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
             }
         }
 
         // extract embeddings
-        if (embd && t_embd && n_outputs > 0) {
+        if (embd.data && t_embd && n_outputs > 0) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
@@ -1694,13 +1693,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 case LLAMA_POOLING_TYPE_NONE:
                     {
                         // extract token embeddings
-                        GGML_ASSERT(embd != nullptr);
+                        GGML_ASSERT(embd.data != nullptr);
                         const uint32_t n_embd_out = hparams.n_embd_out();
 
-                        float * embd_out = embd + n_outputs_prev*n_embd_out;
+                        float * embd_out = embd.data + n_outputs_prev*n_embd_out;
 
                         if (n_outputs) {
                             GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
                             ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
                         }
                     } break;
@@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
             const auto stride = n_vocab;
 
             // async copy the sampling data from the backend to the host
-            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+            copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
 
             copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
             copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
@@ -1841,19 +1840,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     size_t backend_float_count = 0;
     size_t backend_token_count = 0;
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
+    logits.size = has_logits ? n_vocab*n_outputs_max : 0;
+    embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
 
     // Allocate backend sampling output buffers if there are backend samplers configured.
     const bool has_sampling = !sampling.samplers.empty();
     if (has_sampling) {
-        sampling.logits_size = n_vocab*n_outputs_max;
-        sampling.probs_size = n_vocab*n_outputs_max;
-        sampling.sampled_size = n_outputs_max;
-        sampling.candidates_size = n_vocab*n_outputs_max;
-        backend_float_count = sampling.logits_size + sampling.probs_size;
-        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+        backend_float_count = 2 * n_vocab * n_outputs_max;   // logits + probs
+        backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
     }
 
     if (output_ids.empty()) {
@@ -1863,7 +1857,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
     const size_t new_size =
-        (logits_size + embd_size + backend_float_count) * sizeof(float) +
+        (logits.size + embd.size + backend_float_count) * sizeof(float) +
         (                          backend_token_count) * sizeof(llama_token);
 
     // alloc only when more than the current capacity is required
@@ -1878,8 +1872,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         // TODO: not needed?
         buf_output = nullptr;
-        logits = nullptr;
-        embd = nullptr;
+        logits.data = nullptr;
+        embd.data = nullptr;
     }
 
     auto * buft = ggml_backend_cpu_buffer_type();
@@ -1898,35 +1892,32 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
 
-    logits = nullptr;
-    embd = nullptr;
-
     size_t offset = 0;
     uint8_t * base = (uint8_t *) output_base;
 
-    logits = has_logits ? output_base : nullptr;
-    offset += logits_size * sizeof(float);
+    logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
+    offset += logits.size * sizeof(float);
 
-    embd = has_embd ? (float *) (base + offset) : nullptr;
-    offset += embd_size * sizeof(float);
+    embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
+    offset += embd.size * sizeof(float);
 
-    sampling.logits = nullptr;
-    sampling.probs = nullptr;
-    sampling.sampled = nullptr;
-    sampling.candidates = nullptr;
+    sampling.logits = {nullptr, 0};
+    sampling.probs = {nullptr, 0};
+    sampling.sampled = {nullptr, 0};
+    sampling.candidates = {nullptr, 0};
 
     if (has_sampling) {
-        sampling.logits = (float *) (base + offset);
-        offset += sampling.logits_size * sizeof(float);
+        sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.logits.size * sizeof(float);
 
-        sampling.probs = (float *) (base + offset);
-        offset += sampling.probs_size * sizeof(float);
+        sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.probs.size * sizeof(float);
 
-        sampling.sampled = (llama_token *) (base + offset);
-        offset += sampling.sampled_size * sizeof(llama_token);
+        sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
+        offset += sampling.sampled.size * sizeof(llama_token);
 
-        sampling.candidates = (llama_token *) (base + offset);
-        offset += sampling.candidates_size * sizeof(llama_token);
+        sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.candidates.size * sizeof(llama_token);
 
         // The count vectors keep track of the actual number of logits/probs/candidates
         // copied from the backend for each output row.
@@ -1939,7 +1930,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
 
-        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+        std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
     }
 
     // set all ids as invalid (negative)
@@ -1958,38 +1949,38 @@ void llama_context::output_reorder() {
         const uint64_t i0 = output_swaps[s].i0;
         const uint64_t i1 = output_swaps[s].i1;
 
-        if (logits_size > 0) {
+        if (logits.size > 0) {
             for (uint64_t k = 0; k < n_vocab; k++) {
-                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+                std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
             }
         }
 
-        if (embd_size > 0) {
+        if (embd.size > 0) {
             for (uint64_t k = 0; k < n_embd; k++) {
-                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+                std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
             }
         }
 
-        if (sampling.logits && sampling.logits_size > 0) {
+        if (sampling.logits.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+                std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.probs && sampling.probs_size > 0) {
+        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+                std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.candidates && sampling.candidates_size > 0) {
+        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+                std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.sampled && sampling.sampled_size > 0) {
-            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        if (sampling.sampled.has_data()) {
+            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
         }
 
         if (!sampling.logits_count.empty()) {
@@ -2533,12 +2524,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
 
-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+        const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
 
         io.write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            io.write(logits, logits_size * sizeof(float));
+            io.write(logits.data, logits_size * sizeof(float));
         }
     }
@@ -2546,12 +2537,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
 
-        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+        const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
 
         io.write(&embd_size, sizeof(embd_size));
 
         if (embd_size) {
-            io.write(embd, embd_size * sizeof(float));
+            io.write(embd.data, embd_size * sizeof(float));
         }
     }
@@ -2619,12 +2610,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         uint64_t logits_size;
         io.read_to(&logits_size, sizeof(logits_size));
 
-        if (this->logits_size < logits_size) {
+        if (this->logits.size < logits_size) {
            throw std::runtime_error("logits buffer too small");
         }
 
         if (logits_size) {
-            io.read_to(this->logits, logits_size * sizeof(float));
+            io.read_to(this->logits.data, logits_size * sizeof(float));
         }
     }
@@ -2635,12 +2626,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         uint64_t embd_size;
         io.read_to(&embd_size, sizeof(embd_size));
 
-        if (this->embd_size < embd_size) {
+        if (this->embd.size < embd_size) {
            throw std::runtime_error("embeddings buffer too small");
         }
 
         if (embd_size) {
-            io.read_to(this->embd, embd_size * sizeof(float));
+            io.read_to(this->embd.data, embd_size * sizeof(float));
         }
     }
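
Aside: the output_reserve hunks above carve every output buffer out of a single backend allocation by walking a byte offset. A self-contained sketch of that sub-allocation pattern, with a hypothetical helper name and int32_t standing in for llama_token:

    #include <cstddef>
    #include <cstdint>

    template <typename T>
    struct buffer_view {
        T * data;
        size_t size = 0;

        bool has_data() const { return data && size > 0; }
    };

    // Hypothetical helper mirroring output_reserve(): partition one
    // contiguous allocation into typed views by advancing a byte offset.
    static void partition_outputs(uint8_t * base, size_t n_logits, size_t n_sampled,
                                  buffer_view<float> & logits,
                                  buffer_view<int32_t> & sampled) {
        size_t offset = 0;

        logits = {(float *) (base + offset), n_logits};
        offset += logits.size * sizeof(float);

        sampled = {(int32_t *) (base + offset), n_sampled};
        offset += sampled.size * sizeof(int32_t); // offset == total bytes used
    }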

diff --git a/src/llama-context.h b/src/llama-context.h
--- a/src/llama-context.h
+++ b/src/llama-context.h

@@ -4,6 +4,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-impl.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
@@ -269,29 +270,19 @@ private:
     std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t logits_size = 0; // capacity (of floats) for logits
-    float * logits = nullptr;
+    struct buffer_view<float> logits = {nullptr, 0};
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    size_t embd_size = 0; // capacity (of floats) for embeddings
-    float * embd = nullptr;
+    struct buffer_view<float> embd = {nullptr, 0};
 
     // TODO: simplify
     struct sampling_info {
         std::map<llama_seq_id, llama_sampler *> samplers;
 
-        float * logits = nullptr;
-        size_t logits_size = 0;
-        llama_token * sampled = nullptr;
-        size_t sampled_size = 0;
-        float * probs = nullptr;
-        size_t probs_size = 0;
-        llama_token * candidates = nullptr;
-        size_t candidates_size = 0;
+        struct buffer_view<float> logits = {nullptr, 0};
+        struct buffer_view<llama_token> sampled = {nullptr, 0};
+        struct buffer_view<float> probs = {nullptr, 0};
+        struct buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector<uint32_t> logits_count;
         std::vector<uint32_t> probs_count;

diff --git a/src/llama-impl.h b/src/llama-impl.h
--- a/src/llama-impl.h
+++ b/src/llama-impl.h

@@ -49,6 +49,16 @@ struct time_meas {
     int64_t & t_acc;
 };
 
+template <typename T>
+struct buffer_view {
+    T * data;
+    size_t size = 0;
+
+    bool has_data() const {
+        return data && size > 0;
+    }
+};
+
 void replace_all(std::string & s, const std::string & search, const std::string & replace);
 
 // TODO: rename to llama_format ?
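
Usage sketch for the new template (hypothetical values, not from the commit): aggregate initialization works over any backing storage, and has_data() guards both the null-pointer and the zero-size case:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    template <typename T>
    struct buffer_view { // as defined in llama-impl.h above
        T * data;
        size_t size = 0;

        bool has_data() const { return data && size > 0; }
    };

    int main() {
        std::vector<float> backing(8, 0.0f);

        buffer_view<float> view  = {backing.data(), backing.size()};
        buffer_view<float> empty = {nullptr, 0};

        if (view.has_data()) {
            view.data[0] = 1.0f; // callers index through .data, bounded by .size
        }

        std::printf("%d %d\n", (int) view.has_data(), (int) empty.has_data()); // prints "1 0"
        return 0;
    }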