diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 627ffca916..386fab04ac 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -28,8 +28,7 @@ bool llama_batch_allocr::init(
         const llama_memory_i * memory,
         uint32_t n_embd,
         uint32_t n_seq_max,
-        bool output_all,
-        bool sampling) {
+        bool output_all) {
     clear();

     batch = batch_inp;
@@ -146,24 +145,6 @@ bool llama_batch_allocr::init(
        }
    }

-    if (sampling) {
-        std::vector<int> seq_output_count(n_seq_max, 0);
-
-        for (int32_t i = 0; i < batch.n_tokens; ++i) {
-            if (batch.logits[i] == 0) {
-                continue;
-            }
-            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
-                seq_output_count[seq_id]++;
-                if (seq_output_count[seq_id] > 1) {
-                    LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id);
-                    return false;
-                }
-            }
-        }
-    }
-
    //
    // compute stats
    //
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 05c03d018d..8e6fac0efa 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -81,8 +81,7 @@ public:
             const llama_memory_i * memory,
             uint32_t n_embd,
             uint32_t n_seq_max,
-            bool output_all,
-            bool sampling = false);
+            bool output_all);

    const llama_batch & get_batch() const;

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index cfd71748de..7ac7ac04ea 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1388,13 +1388,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
    const int64_t n_embd = hparams.n_embd_inp();

    // when computing embeddings, all tokens are output
-    const bool output_all = cparams.embeddings;
+    const bool output_all   = cparams.embeddings;
    const bool has_samplers = !sampling.samplers.empty();

-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
-            cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
-            output_all,
-            has_samplers)) {
+    const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
+
+    // TODO: avoid this workaround in the future
+    if (has_samplers && batch_inp.logits) {
+        std::vector<int> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
+            if (batch_inp.logits[i] == 0) {
+                continue;
+            }
+
+            const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
+
+            for (int32_t s = 0; s < ns; ++s) {
+                const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
+
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
+                            __func__, seq_id, seq_output_count[seq_id]);
+                    return -1;
+                }
+            }
+        }
+    }
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1;
    }
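
Not part of the patch: for reference, the rule the relocated check enforces (at most one output token per sequence while backend samplers are active) can be sketched standalone. batch_view, seq_id_t, and validate_one_output_per_seq below are hypothetical stand-ins for illustration, not llama.cpp API; the counting logic mirrors the block added to llama_context::decode.

// Standalone sketch of the per-sequence output-count rule (illustrative only).
#include <cstdint>
#include <cstdio>
#include <vector>

using seq_id_t = int32_t; // stand-in for llama_seq_id

struct batch_view {
    int32_t                            n_tokens;
    std::vector<int8_t>                logits;  // 1 -> token requests an output
    std::vector<std::vector<seq_id_t>> seq_ids; // sequences each token belongs to
};

// returns false if any sequence requests more than one output token
static bool validate_one_output_per_seq(const batch_view & batch, uint32_t n_seq_max) {
    std::vector<int> seq_output_count(n_seq_max, 0);

    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (batch.logits[i] == 0) {
            continue;
        }
        for (const seq_id_t seq_id : batch.seq_ids[i]) {
            if (++seq_output_count[seq_id] > 1) {
                std::fprintf(stderr, "sequence %d requests more than one output token\n", seq_id);
                return false;
            }
        }
    }
    return true;
}

int main() {
    // two tokens of sequence 0 are both flagged for output -> rejected,
    // mirroring the -1 return added to llama_context::decode above
    const batch_view batch = { 2, { 1, 1 }, { { 0 }, { 0 } } };
    std::printf("valid: %d\n", validate_one_output_per_seq(batch, /*n_seq_max=*/4) ? 1 : 0);
    return 0;
}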