This commit is contained in:
Daniel Bevenius 2026-02-06 14:56:19 +01:00 committed by GitHub
commit c72416f496
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 65 deletions

View File

@ -757,7 +757,7 @@ float * llama_context::get_embeddings() {
}
llama_token * llama_context::get_sampled_tokens() const{
return sampling.sampled;
return sampling.sampled.data;
}
float * llama_context::get_embeddings_ith(int32_t i) {
@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();
if (sampling.sampled == nullptr) {
if (!sampling.sampled.has_data()) {
return LLAMA_TOKEN_NULL;
}
try {
const int64_t row = output_resolve_row(idx);
GGML_ASSERT(row < (int64_t) sampling.sampled_size);
return sampling.sampled[row];
GGML_ASSERT(row < (int64_t) sampling.sampled.size);
return sampling.sampled.data[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
return LLAMA_TOKEN_NULL;
@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
float * llama_context::get_sampled_probs_ith(int32_t idx) {
output_reorder();
if (sampling.probs == nullptr) {
if (!sampling.probs.has_data()) {
return nullptr;
}
@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
return nullptr;
}
return sampling.probs + row*model.vocab.n_tokens();
return sampling.probs.data + row*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
return nullptr;
@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
float * llama_context::get_sampled_logits_ith(int32_t idx) {
output_reorder();
if (sampling.logits == nullptr) {
if (!sampling.logits.has_data()) {
return nullptr;
}
@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
return nullptr;
}
return sampling.logits + row*model.vocab.n_tokens();
return sampling.logits.data + row*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
return nullptr;
@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
try {
const int64_t row = output_resolve_row(idx);
if (sampling.candidates != nullptr &&
if (sampling.candidates.has_data() &&
(size_t) row < sampling.candidates_count.size() &&
sampling.candidates_count[row] > 0) {
return sampling.candidates + row*model.vocab.n_tokens();
return sampling.candidates.data + row*model.vocab.n_tokens();
}
} catch (const std::exception & err) {
// fallback to full vocab list
@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
size_t llama_context::get_sampled_candidates_count(int32_t idx) {
output_reorder();
if (sampling.candidates == nullptr) {
if (!sampling.candidates.has_data()) {
return 0;
}
@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
size_t llama_context::get_sampled_logits_count(int32_t idx) {
output_reorder();
if (sampling.logits == nullptr) {
if (!sampling.logits.has_data()) {
return model.vocab.n_tokens();
}
@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
size_t llama_context::get_sampled_probs_count(int32_t idx) {
output_reorder();
if (sampling.probs == nullptr) {
if (!sampling.probs.has_data()) {
return 0;
}
@ -1363,11 +1363,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
static void copy_tensor_async_ints(
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
llama_token * sampled,
size_t sampled_size,
const buffer_view<llama_token> & sampled,
const std::map<llama_seq_id, uint32_t> & seq_to_row,
ggml_backend_sched_t sched) {
if (sampled == nullptr) {
if (!sampled.has_data()) {
return;
}
@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints(
}
const uint32_t row = it->second;
GGML_ASSERT(row < sampled_size);
GGML_ASSERT(row < sampled.size);
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
}
}
static void copy_tensor_async_floats(
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
float * dst,
const buffer_view<float> & dst,
size_t stride,
std::vector<uint32_t> & counts,
const std::map<llama_seq_id, uint32_t> & seq_to_row,
ggml_backend_sched_t sched) {
if (dst == nullptr) {
if (!dst.has_data()) {
return;
}
@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats(
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
float * row_ptr = dst + (size_t) row * stride;
float * row_ptr = dst.data + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
// Update the actual number of logits/probabilities that were written for this row.
@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats(
static void copy_tensor_async_candidates(
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
llama_token * dst,
const buffer_view<llama_token> & dst,
size_t stride,
std::vector<uint32_t> & counts,
const std::map<llama_seq_id, uint32_t> & seq_to_row,
ggml_backend_sched_t sched) {
if (dst == nullptr) {
if (!dst.has_data()) {
return;
}
@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates(
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
llama_token * row_ptr = dst + (size_t) row * stride;
llama_token * row_ptr = dst.data + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
// Update the actual number of candidates that were written.
@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const auto stride = n_vocab;
// async copy the sampling data from the backend to the host
copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
@ -1847,13 +1846,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
if (has_sampling) {
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
sampling.candidates_size = n_vocab*n_outputs_max;
backend_float_count = sampling.logits_size + sampling.probs_size;
backend_token_count = sampling.sampled_size + sampling.candidates_size;
backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
}
if (output_ids.empty()) {
@ -1910,23 +1904,23 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd = has_embd ? (float *) (base + offset) : nullptr;
offset += embd_size * sizeof(float);
sampling.logits = nullptr;
sampling.probs = nullptr;
sampling.sampled = nullptr;
sampling.candidates = nullptr;
sampling.logits = {nullptr, 0};
sampling.probs = {nullptr, 0};
sampling.sampled = {nullptr, 0};
sampling.candidates = {nullptr, 0};
if (has_sampling) {
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.logits.size * sizeof(float);
sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);
sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.probs.size * sizeof(float);
sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);
sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
offset += sampling.sampled.size * sizeof(llama_token);
sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);
sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.candidates.size * sizeof(llama_token);
// The count vectors keep track of the actual number of logits/probs/candidates
// copied from the backend for each output row.
@ -1939,7 +1933,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
}
// set all ids as invalid (negative)
@ -1970,26 +1964,26 @@ void llama_context::output_reorder() {
}
}
if (sampling.logits && sampling.logits_size > 0) {
if (sampling.logits.has_data()) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
}
}
if (sampling.probs && sampling.probs_size > 0) {
if (sampling.probs.has_data()) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
}
}
if (sampling.candidates && sampling.candidates_size > 0) {
if (sampling.candidates.has_data()) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
}
}
if (sampling.sampled && sampling.sampled_size > 0) {
std::swap(sampling.sampled[i0], sampling.sampled[i1]);
if (sampling.sampled.has_data()) {
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
}
if (!sampling.logits_count.empty()) {

View File

@ -4,6 +4,7 @@
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
@ -277,21 +278,13 @@ private:
size_t embd_size = 0; // capacity (of floats) for embeddings
float * embd = nullptr;
// TODO: simplify
struct sampling_info {
std::map<llama_seq_id, llama_sampler *> samplers;
float * logits = nullptr;
size_t logits_size = 0;
llama_token * sampled = nullptr;
size_t sampled_size = 0;
float * probs = nullptr;
size_t probs_size = 0;
llama_token * candidates = nullptr;
size_t candidates_size = 0;
struct buffer_view<float> logits = {nullptr, 0};
struct buffer_view<llama_token> sampled = {nullptr, 0};
struct buffer_view<float> probs = {nullptr, 0};
struct buffer_view<llama_token> candidates = {nullptr, 0};
std::vector<uint32_t> logits_count;
std::vector<uint32_t> probs_count;

View File

@ -49,6 +49,16 @@ struct time_meas {
int64_t & t_acc;
};
template <typename T>
struct buffer_view {
T * data;
size_t size = 0;
bool has_data() const {
return data && size > 0;
}
};
void replace_all(std::string & s, const std::string & search, const std::string & replace);
// TODO: rename to llama_format ?