llama : assert at most one output token per sequence

This commit is contained in:
Georgi Gerganov 2025-12-31 17:44:27 +02:00
parent 4c3d5422ad
commit 435c96709b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 30 additions and 27 deletions

View File

@ -28,8 +28,7 @@ bool llama_batch_allocr::init(
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all,
bool sampling) {
bool output_all) {
clear();
batch = batch_inp;
@ -146,24 +145,6 @@ bool llama_batch_allocr::init(
}
}
if (sampling) {
std::vector<int32_t> seq_output_count(n_seq_max, 0);
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
continue;
}
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = batch.seq_id[i][s];
seq_output_count[seq_id]++;
if (seq_output_count[seq_id] > 1) {
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id);
return false;
}
}
}
}
//
// compute stats
//

View File

@ -81,8 +81,7 @@ public:
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all,
bool sampling = false);
bool output_all);
const llama_batch & get_batch() const;

View File

@ -1388,13 +1388,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
const int64_t n_embd = hparams.n_embd_inp();
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
const bool output_all = cparams.embeddings;
const bool has_samplers = !sampling.samplers.empty();
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
output_all,
has_samplers)) {
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
// TODO: avoid this workaround in the future
if (has_samplers && batch_inp.logits) {
std::vector<int32_t> seq_output_count(n_seq_max, 0);
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
if (batch_inp.logits[i] == 0) {
continue;
}
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
for (int32_t s = 0; s < ns; ++s) {
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
seq_output_count[seq_id]++;
if (seq_output_count[seq_id] > 1) {
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
__func__, seq_id, seq_output_count[seq_id]);
return -1;
}
}
}
}
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}