llama : assert at most one output token per sequence
This commit is contained in:
parent
4c3d5422ad
commit
435c96709b
|
|
@ -28,8 +28,7 @@ bool llama_batch_allocr::init(
|
||||||
const llama_memory_i * memory,
|
const llama_memory_i * memory,
|
||||||
uint32_t n_embd,
|
uint32_t n_embd,
|
||||||
uint32_t n_seq_max,
|
uint32_t n_seq_max,
|
||||||
bool output_all,
|
bool output_all) {
|
||||||
bool sampling) {
|
|
||||||
clear();
|
clear();
|
||||||
|
|
||||||
batch = batch_inp;
|
batch = batch_inp;
|
||||||
|
|
@ -146,24 +145,6 @@ bool llama_batch_allocr::init(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sampling) {
|
|
||||||
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
|
||||||
|
|
||||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
||||||
if (batch.logits[i] == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
||||||
const llama_seq_id seq_id = batch.seq_id[i][s];
|
|
||||||
seq_output_count[seq_id]++;
|
|
||||||
if (seq_output_count[seq_id] > 1) {
|
|
||||||
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// compute stats
|
// compute stats
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -81,8 +81,7 @@ public:
|
||||||
const llama_memory_i * memory,
|
const llama_memory_i * memory,
|
||||||
uint32_t n_embd,
|
uint32_t n_embd,
|
||||||
uint32_t n_seq_max,
|
uint32_t n_seq_max,
|
||||||
bool output_all,
|
bool output_all);
|
||||||
bool sampling = false);
|
|
||||||
|
|
||||||
const llama_batch & get_batch() const;
|
const llama_batch & get_batch() const;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1388,13 +1388,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
const int64_t n_embd = hparams.n_embd_inp();
|
const int64_t n_embd = hparams.n_embd_inp();
|
||||||
|
|
||||||
// when computing embeddings, all tokens are output
|
// when computing embeddings, all tokens are output
|
||||||
const bool output_all = cparams.embeddings;
|
const bool output_all = cparams.embeddings;
|
||||||
const bool has_samplers = !sampling.samplers.empty();
|
const bool has_samplers = !sampling.samplers.empty();
|
||||||
|
|
||||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
|
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
|
||||||
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
|
|
||||||
output_all,
|
// TODO: avoid this workaround in the future
|
||||||
has_samplers)) {
|
if (has_samplers && batch_inp.logits) {
|
||||||
|
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
|
||||||
|
if (batch_inp.logits[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
|
||||||
|
|
||||||
|
for (int32_t s = 0; s < ns; ++s) {
|
||||||
|
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
|
||||||
|
|
||||||
|
seq_output_count[seq_id]++;
|
||||||
|
if (seq_output_count[seq_id] > 1) {
|
||||||
|
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
|
||||||
|
__func__, seq_id, seq_output_count[seq_id]);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue