diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 86a1a4ba18..f0866a9ca1 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -28,7 +28,8 @@ bool llama_batch_allocr::init(
         const llama_memory_i * memory,
         uint32_t n_embd,
         uint32_t n_seq_max,
-        bool output_all) {
+        bool output_all,
+        bool backend_sampling) {
     clear();
 
     batch = batch_inp;
@@ -145,6 +146,24 @@ bool llama_batch_allocr::init(
         }
     }
 
+    if (backend_sampling) {
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                continue;
+            }
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (seq_id = %d)\n", __func__, seq_id);
+                    return false;
+                }
+            }
+        }
+    }
+
     //
     // compute stats
     //
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 209cf3699d..d8751274f3 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -79,7 +79,8 @@ public:
             const llama_memory_i * memory,
             uint32_t n_embd,
             uint32_t n_seq_max,
-            bool output_all);
+            bool output_all,
+            bool backend_sampling = false);
 
     const llama_batch & get_batch() const;
 
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 25d3528434..0b7f3adf9b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1155,7 +1155,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
    const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
+                      cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
+                      output_all,
+                      !samplers.empty())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index aa018e645a..c6d0d1a38d 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -629,6 +629,46 @@ static void test_backend_set_sampler(const char * model_path) {
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 }
 
+static void test_backend_max_outputs(const char * model_path) {
+    test_model_context test_ctx;
+
+    const int seq_id = 0;
+    const int32_t seed = 88;
+    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+    std::string prompt = "Hello";
+
+    std::vector<llama_token> tokens;
+    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+
+    std::vector<llama_token> prompt_tokens(32);
+    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+                                  prompt_tokens.data(), prompt_tokens.size(),
+                                  false, false);
+    for (int i = 0; i < n_tokens; i++) {
+        tokens.push_back(prompt_tokens[i]);
+    }
+
+    for (size_t i = 0; i < tokens.size(); i++) {
+        // mark every token as an output to trigger the error
+        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+    }
+
+    printf(">>> test_max_outputs expected error start:\n");
+    const int ret = llama_decode(test_ctx.ctx, batch);
+    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
+    printf("<<< test_max_outputs expected error end.\n");
+    llama_batch_free(batch);
+}
+
 struct backend_test_case {
     const char * name;
     void (*fn)(const char *);
@@ -644,6 +684,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "dist",          test_backend_dist_sampling,         true },
     { "dist_and_cpu",  test_backend_dist_sampling_and_cpu, true },
     { "set_sampler",   test_backend_set_sampler,           true },
+    { "max_outputs",   test_backend_max_outputs,           true },
 };
 
 struct backend_cli_args {