sampling : support multiple outputs per sequence

This commit adds support for multiple outputs per sequence in the
backend sampling implementation.

The main motivation for this change is to support speculative
decoding with backend samplers, where multiple outputs for the
same sequence are needed.
This commit is contained in:
Daniel Bevenius 2026-02-23 14:31:29 +01:00
parent a8b192b6ec
commit 1138d5c2d9
No known key found for this signature in database
4 changed files with 129 additions and 115 deletions

View File

@ -1296,8 +1296,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
return 0;
}
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
std::map<llama_seq_id, uint32_t> seq_to_row;
static std::map<llama_seq_id, std::vector<uint32_t>> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
std::map<llama_seq_id, std::vector<uint32_t>> seq_to_row;
// how many output tokens we have seen so far for this ubatch.
uint32_t local = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
@ -1308,96 +1308,114 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
const llama_seq_id seq_id = ubatch.seq_id[i][0];
// row_offset is the number of output tokens before this ubatch.
seq_to_row[seq_id] = row_offset + local;
seq_to_row[seq_id].push_back(row_offset + local);
++local;
}
return seq_to_row;
}
static void copy_tensor_async_ints(
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
const std::map<llama_seq_id, std::vector<ggml_tensor*>> & tensor_map,
const buffer_view<llama_token> & sampled,
const std::map<llama_seq_id, uint32_t> & seq_to_row,
const std::map<llama_seq_id, std::vector<uint32_t>> & seq_to_row,
ggml_backend_sched_t sched) {
if (!sampled.has_data()) {
return;
}
for (const auto & [seq_id, tensor] : tensor_map) {
for (const auto & [seq_id, tensors] : tensor_map) {
auto it = seq_to_row.find(seq_id);
if (it == seq_to_row.end()) {
continue;
}
const uint32_t row = it->second;
GGML_ASSERT(row < sampled.size);
const std::vector<uint32_t> & rows = it->second;
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
for (size_t i = 0; i < tensors.size(); ++i) {
const uint32_t row = rows[i];
ggml_tensor * tensor = tensors[i];
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
GGML_ASSERT(row < sampled.size);
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
}
}
}
static void copy_tensor_async_floats(
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
const std::map<llama_seq_id, std::vector<ggml_tensor*>> & tensor_map,
const buffer_view<float> & dst,
size_t stride,
std::vector<uint32_t> & counts,
const std::map<llama_seq_id, uint32_t> & seq_to_row,
const std::map<llama_seq_id, std::vector<uint32_t>> & seq_to_row,
ggml_backend_sched_t sched) {
if (!dst.has_data()) {
return;
}
for (const auto & [seq_id, tensor] : tensor_map) {
for (const auto & [seq_id, tensors] : tensor_map) {
auto it = seq_to_row.find(seq_id);
if (it == seq_to_row.end()) {
continue;
}
const uint32_t row = it->second;
GGML_ASSERT(row < counts.size());
const std::vector<uint32_t> & rows = it->second;
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
for (size_t i = 0; i < tensors.size(); ++i) {
const uint32_t row = rows[i];
ggml_tensor * tensor = tensors[i];
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
float * row_ptr = dst.data + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
GGML_ASSERT(row < counts.size());
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
// Update the actual number of logits/probabilities that were written for this row.
counts[row] = ggml_nelements(tensor);
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
float * row_ptr = dst.data + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
// Update the actual number of logits/probabilities that were written for this row.
counts[row] = ggml_nelements(tensor);
}
}
}
static void copy_tensor_async_candidates(
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
const std::map<llama_seq_id, std::vector<ggml_tensor*>> & tensor_map,
const buffer_view<llama_token> & dst,
size_t stride,
std::vector<uint32_t> & counts,
const std::map<llama_seq_id, uint32_t> & seq_to_row,
const std::map<llama_seq_id, std::vector<uint32_t>> & seq_to_row,
ggml_backend_sched_t sched) {
if (!dst.has_data()) {
return;
}
for (const auto & [seq_id, tensor] : tensor_map) {
for (const auto & [seq_id, tensors] : tensor_map) {
auto it = seq_to_row.find(seq_id);
if (it == seq_to_row.end()) {
continue;
}
const uint32_t row = it->second;
GGML_ASSERT(row < counts.size());
const std::vector<uint32_t> & rows = it->second;
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
for (size_t i = 0; i < tensors.size(); ++i) {
const uint32_t row = rows[i];
ggml_tensor * tensor = tensors[i];
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
llama_token * row_ptr = dst.data + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
GGML_ASSERT(row < counts.size());
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
// Update the actual number of candidates that were written.
counts[row] = ggml_nelements(tensor);
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
llama_token * row_ptr = dst.data + (size_t) row * stride;
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
// Update the actual number of candidates that were written.
counts[row] = ggml_nelements(tensor);
}
}
}
@ -1443,30 +1461,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
// TODO: avoid this workaround in the future
if (has_samplers && batch_inp.logits) {
std::vector<int32_t> seq_output_count(n_seq_max, 0);
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
if (batch_inp.logits[i] == 0) {
continue;
}
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
for (int32_t s = 0; s < ns; ++s) {
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
seq_output_count[seq_id]++;
if (seq_output_count[seq_id] > 1) {
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
__func__, seq_id, seq_output_count[seq_id]);
return -1;
}
}
}
}
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;

View File

@ -774,24 +774,24 @@ void llm_graph_result::set_outputs() {
if (t_embd_pooled != nullptr) {
ggml_set_output(t_embd_pooled);
}
for (auto & [seq_id, t] : t_sampled) {
if (t != nullptr) {
ggml_set_output(t);
for (auto & [seq_id, tensors] : t_sampled) {
for (ggml_tensor * tensor : tensors) {
ggml_set_output(tensor);
}
}
for (auto & [seq_id, t] : t_sampled_probs) {
if (t != nullptr) {
ggml_set_output(t);
for (auto & [seq_id, tensors] : t_sampled_probs) {
for (ggml_tensor * tensor : tensors) {
ggml_set_output(tensor);
}
}
for (auto & [seq_id, t] : t_sampled_logits) {
if (t != nullptr) {
ggml_set_output(t);
for (auto & [seq_id, tensors] : t_sampled_logits) {
for (ggml_tensor * tensor : tensors) {
ggml_set_output(tensor);
}
}
for (auto & [seq_id, t] : t_candidates) {
if (t != nullptr) {
ggml_set_output(t);
for (auto & [seq_id, tensors] : t_candidates) {
for (ggml_tensor * tensor : tensors) {
ggml_set_output(tensor);
}
}
}
@ -2580,13 +2580,13 @@ void llm_graph_context::build_sampling() const {
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
res->add_input(std::move(inp_sampling));
std::map<llama_seq_id, int32_t> seq_to_logit_row;
std::map<llama_seq_id, std::vector<int32_t>> seq_to_logit_rows;
int32_t logit_row_idx = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
if (ubatch.output[i]) {
llama_seq_id seq_id = ubatch.seq_id[i][0];
seq_to_logit_row[seq_id] = logit_row_idx;
seq_to_logit_rows[seq_id].push_back(logit_row_idx);
logit_row_idx++;
}
}
@ -2600,47 +2600,52 @@ void llm_graph_context::build_sampling() const {
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
for (const auto & [seq_id, sampler] : samplers) {
const auto it = seq_to_logit_row.find(seq_id);
const auto row_it = seq_to_logit_rows.find(seq_id);
// inactive samplers always work on the first row
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
// row_it maps this sequence id to its list of output row indices
static const std::vector<int32_t> default_row = {0};
const std::vector<int32_t> & logit_rows = row_it != seq_to_logit_rows.end() ? row_it->second : default_row;
for (const int32_t row_idx : logit_rows) {
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
// inactive samplers always work on the first row
const int i_out = row_it != seq_to_logit_rows.end() ? 1 : 0;
struct llama_sampler_data data = {
/*.logits =*/ logits_seq,
/*.probs =*/ nullptr,
/*.sampled =*/ nullptr,
/*.candidates =*/ nullptr,
};
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
assert(sampler->iface->backend_apply);
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
struct llama_sampler_data data = {
/*.logits =*/ logits_seq,
/*.probs =*/ nullptr,
/*.sampled =*/ nullptr,
/*.candidates =*/ nullptr,
};
if (data.sampled != nullptr) {
res->t_sampled[seq_id] = data.sampled;
outs[1] = data.sampled;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
assert(sampler->iface->backend_apply);
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = data.probs;
outs[1] = data.probs;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.sampled != nullptr) {
res->t_sampled[seq_id].push_back(data.sampled);
outs[1] = data.sampled;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.logits != nullptr) {
res->t_sampled_logits[seq_id] = data.logits;
outs[1] = data.logits;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.probs != nullptr) {
res->t_sampled_probs[seq_id].push_back(data.probs);
outs[1] = data.probs;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.candidates != nullptr) {
res->t_candidates[seq_id] = data.candidates;
outs[1] = data.candidates;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
if (data.logits != nullptr) {
res->t_sampled_logits[seq_id].push_back(data.logits);
outs[1] = data.logits;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.candidates != nullptr) {
res->t_candidates[seq_id].push_back(data.candidates);
outs[1] = data.candidates;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
}
}

View File

@ -662,10 +662,10 @@ public:
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
std::map<llama_seq_id, ggml_tensor*> t_sampled;
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_sampled_logits;
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_candidates;
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_sampled;
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_sampled_probs;
std::vector<llm_graph_input_ptr> inputs;

View File

@ -968,7 +968,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) {
printf("backend-cpu mixed batch test PASSED\n");
}
static void test_backend_max_outputs(const test_params & params) {
static void test_backend_multiple_outputs(const test_params & params) {
const int seq_id = 0;
const int32_t seed = 88;
@ -994,17 +994,32 @@ static void test_backend_max_outputs(const test_params & params) {
}
for (size_t i = 0; i < tokens.size(); i++) {
// set all tokens as output to trigger error
// set all tokens as output to get multiple outputs for a single sequence.
common_batch_add(batch, tokens[i], i, { seq_id }, true);
}
printf(">>> test_max_outputs expected error start:\n");
const int ret = llama_decode(test_ctx.ctx.get(), batch);
GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence");
printf("<<< test_max_outputs expected error end.\n");
if (ret != 0) {
GGML_ASSERT(false && "Failed to decode sequence with multiple outputs");
}
std::vector<llama_token> sampled_tokens;
for (int i = 0; i < batch.n_tokens; i++) {
if (batch.logits[i]) {
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), i);
const std::string token_str = test_ctx.token_to_piece(token, false);
//printf("Position %d: token id=%d, string='%s'\n", i, token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
sampled_tokens.push_back(token);
}
}
GGML_ASSERT((int)sampled_tokens.size() == batch.n_tokens);
printf("Sampled %zu tokens for sequence %d\n", sampled_tokens.size(), seq_id);
llama_batch_free(batch);
printf("backend max outputs test PASSED\n");
printf("backend multiple outputs test PASSED\n");
}
struct backend_test_case {
@ -1023,7 +1038,7 @@ static const backend_test_case BACKEND_TESTS[] = {
{ "dist", test_backend_dist_sampling, true },
{ "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
{ "set_sampler", test_backend_set_sampler, true },
{ "max_outputs", test_backend_max_outputs, true },
{ "multiple_outputs",test_backend_multiple_outputs, true },
{ "mixed", test_backend_mixed_sampling, true },
{ "min_p", test_backend_min_p_sampling, true },
{ "cpu_mixed", test_backend_cpu_mixed_batch, true },