sampling : ensure at most one output token per seq

This commit adds a check in the batch allocator to ensure that when
backend sampling is enabled, at most one output token is specified per
sequence.
This commit is contained in:
Daniel Bevenius 2025-11-18 16:01:54 +01:00
parent 82957a90f2
commit 311c1a347f
No known key found for this signature in database
4 changed files with 67 additions and 3 deletions

View File

@ -28,7 +28,8 @@ bool llama_batch_allocr::init(
const llama_memory_i * memory, const llama_memory_i * memory,
uint32_t n_embd, uint32_t n_embd,
uint32_t n_seq_max, uint32_t n_seq_max,
bool output_all) { bool output_all,
bool backend_sampling) {
clear(); clear();
batch = batch_inp; batch = batch_inp;
@ -145,6 +146,24 @@ bool llama_batch_allocr::init(
} }
} }
// When backend sampling is enabled, reject any batch that marks more than one
// token of the same sequence as an output token.
if (backend_sampling) {
// per-sequence tally of tokens flagged as output; one slot per sequence id below n_seq_max
std::vector<int32_t> seq_output_count(n_seq_max, 0);
for (int32_t i = 0; i < batch.n_tokens; ++i) {
// tokens whose logits flag is zero are not outputs and are not counted
if (batch.logits[i] == 0) {
continue;
}
// a token can be assigned to several sequences; it is counted once for each of them
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = batch.seq_id[i][s];
// NOTE(review): indexed without a range check here - assumes seq_id was already
// validated to lie in [0, n_seq_max) earlier in init(); confirm
seq_output_count[seq_id]++;
// fail on the first sequence that receives a second output token
if (seq_output_count[seq_id] > 1) {
LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (%d)\n", __func__, seq_id);
return false;
}
}
}
}
// //
// compute stats // compute stats
// //

View File

@ -79,7 +79,8 @@ public:
const llama_memory_i * memory, const llama_memory_i * memory,
uint32_t n_embd, uint32_t n_embd,
uint32_t n_seq_max, uint32_t n_seq_max,
bool output_all); bool output_all,
bool backend_sampling = false);
const llama_batch & get_batch() const; const llama_batch & get_batch() const;

View File

@ -1155,7 +1155,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
// when computing embeddings, all tokens are output // when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings; const bool output_all = cparams.embeddings;
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
output_all,
!samplers.empty())) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1; return -1;
} }

View File

@ -629,6 +629,46 @@ static void test_backend_set_sampler(const char * model_path) {
printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str()); printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
} }
static void test_backend_max_outputs(const char * model_path) {
test_model_context test_ctx;
const int seq_id = 0;
const int32_t seed = 88;
llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
return;
}
llama_batch batch = llama_batch_init(512, 0, 1);
std::string prompt = "Hello";
std::vector<llama_token> tokens;
tokens.push_back(llama_vocab_bos(test_ctx.vocab));
std::vector<llama_token> prompt_tokens(32);
int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
prompt_tokens.data(), prompt_tokens.size(),
false, false);
for (int i = 0; i < n_tokens; i++) {
tokens.push_back(prompt_tokens[i]);
}
for (size_t i = 0; i < tokens.size(); i++) {
// set all tokens as output to trigger error
common_batch_add(batch, tokens[i], i, { seq_id }, true);
}
printf(">>> test_max_outputs expected error start:\n");
const int ret = llama_decode(test_ctx.ctx, batch);
GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence");
printf("<<< test_max_outputs expected error end.\n");
llama_batch_free(batch);
}
struct backend_test_case { struct backend_test_case {
const char * name; const char * name;
void (*fn)(const char *); void (*fn)(const char *);
@ -644,6 +684,7 @@ static const backend_test_case BACKEND_TESTS[] = {
{ "dist", test_backend_dist_sampling, true }, { "dist", test_backend_dist_sampling, true },
{ "dist_and_cpu", test_backend_dist_sampling_and_cpu, true }, { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
{ "set_sampler", test_backend_set_sampler, true }, { "set_sampler", test_backend_set_sampler, true },
{ "max_outputs", test_backend_max_outputs, true },
}; };
struct backend_cli_args { struct backend_cli_args {