llama : assert at most one output token per sequence
This commit is contained in:
parent
4c3d5422ad
commit
435c96709b
|
|
@ -28,8 +28,7 @@ bool llama_batch_allocr::init(
|
|||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all,
|
||||
bool sampling) {
|
||||
bool output_all) {
|
||||
clear();
|
||||
|
||||
batch = batch_inp;
|
||||
|
|
@ -146,24 +145,6 @@ bool llama_batch_allocr::init(
|
|||
}
|
||||
}
|
||||
|
||||
if (sampling) {
|
||||
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
seq_output_count[seq_id]++;
|
||||
if (seq_output_count[seq_id] > 1) {
|
||||
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// compute stats
|
||||
//
|
||||
|
|
|
|||
|
|
@ -81,8 +81,7 @@ public:
|
|||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all,
|
||||
bool sampling = false);
|
||||
bool output_all);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -1388,13 +1388,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
const bool output_all = cparams.embeddings;
|
||||
const bool has_samplers = !sampling.samplers.empty();
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
|
||||
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
|
||||
output_all,
|
||||
has_samplers)) {
|
||||
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
|
||||
|
||||
// TODO: avoid this workaround in the future
|
||||
if (has_samplers && batch_inp.logits) {
|
||||
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
||||
|
||||
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
|
||||
if (batch_inp.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
|
||||
|
||||
for (int32_t s = 0; s < ns; ++s) {
|
||||
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
|
||||
|
||||
seq_output_count[seq_id]++;
|
||||
if (seq_output_count[seq_id] > 1) {
|
||||
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
|
||||
__func__, seq_id, seq_output_count[seq_id]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue