sampling : ensure at most one output token per seq
This commit adds a check in the batch allocator to ensure that when backend sampling is enabled, at most one output token is specified per sequence.
parent 82957a90f2
commit 311c1a347f
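Backend sampling produces one sampled token per sequence per decode call, so a batch may request at most one output position per sequence. Below is a minimal sketch of a compliant batch; common_batch_add is the helper from llama.cpp's common library and add_prompt_for_seq is a hypothetical name, neither is part of this diff:

#include "common.h"

#include <vector>

// mark only the final position of a sequence as an output, so the
// new one-output-per-sequence check in llama_batch_allocr::init passes
static void add_prompt_for_seq(llama_batch & batch,
                               const std::vector<llama_token> & toks,
                               llama_seq_id seq_id) {
    for (size_t i = 0; i < toks.size(); ++i) {
        const bool is_last = (i == toks.size() - 1);
        common_batch_add(batch, toks[i], (llama_pos) i, { seq_id }, is_last);
    }
}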
@@ -28,7 +28,8 @@ bool llama_batch_allocr::init(
         const llama_memory_i * memory,
         uint32_t n_embd,
         uint32_t n_seq_max,
-        bool output_all) {
+        bool output_all,
+        bool backend_sampling) {
     clear();
 
     batch = batch_inp;
@@ -145,6 +146,24 @@ bool llama_batch_allocr::init(
         }
     }
 
+    if (backend_sampling) {
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                continue;
+            }
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (%d)\n", __func__, seq_id);
+                    return false;
+                }
+            }
+        }
+    }
+
     //
     // compute stats
     //
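Note how the counting works: a token that belongs to several sequences (batch.n_seq_id[i] > 1) increments the counter of every sequence it is a member of, so one shared output token consumes the single-output budget of each of those sequences. A standalone illustration of the rule, detached from the llama.cpp types and using made-up batch values:

#include <cstdio>
#include <vector>

int main() {
    const int n_seq_max = 2;
    // hypothetical batch: token 0 is an output shared by seq 0 and seq 1,
    // token 1 is a second output for seq 0, so seq 0 exceeds its budget
    const std::vector<bool>             output = { true, true };
    const std::vector<std::vector<int>> seqs   = { { 0, 1 }, { 0 } };

    std::vector<int> seq_output_count(n_seq_max, 0);
    for (size_t i = 0; i < output.size(); ++i) {
        if (!output[i]) {
            continue;
        }
        for (int s : seqs[i]) {
            if (++seq_output_count[s] > 1) {
                std::printf("sequence %d has more than one output token\n", s);
                return 1;
            }
        }
    }
    std::printf("batch ok\n");
    return 0;
}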
@@ -79,7 +79,8 @@ public:
             const llama_memory_i * memory,
             uint32_t n_embd,
             uint32_t n_seq_max,
-            bool output_all);
+            bool output_all,
+            bool backend_sampling = false);
 
     const llama_batch & get_batch() const;
 
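The new parameter is defaulted to false, so existing callers of llama_batch_allocr::init that are not aware of backend sampling compile unchanged and keep the old, permissive behavior.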
@@ -1155,7 +1155,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
+                cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
+                output_all,
+                !samplers.empty())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
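The decode path derives the flag from !samplers.empty(): the stricter validation only applies when at least one backend sampler chain is attached to the context, so the regular CPU-side sampling path is unaffected. (Presumably samplers is the context's per-sequence collection of backend sampler chains introduced on this branch; its exact type is not visible in this diff.)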
@@ -629,6 +629,46 @@ static void test_backend_set_sampler(const char * model_path) {
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 }
 
+static void test_backend_max_outputs(const char * model_path) {
+    test_model_context test_ctx;
+
+    const int seq_id = 0;
+    const int32_t seed = 88;
+    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+    std::string prompt = "Hello";
+
+    std::vector<llama_token> tokens;
+    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+
+    std::vector<llama_token> prompt_tokens(32);
+    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+                                  prompt_tokens.data(), prompt_tokens.size(),
+                                  false, false);
+    for (int i = 0; i < n_tokens; i++) {
+        tokens.push_back(prompt_tokens[i]);
+    }
+
+    for (size_t i = 0; i < tokens.size(); i++) {
+        // mark all tokens as output to trigger the error
+        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+    }
+
+    printf(">>> test_max_outputs expected error start:\n");
+    const int ret = llama_decode(test_ctx.ctx, batch);
+    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
+    printf("<<< test_max_outputs expected error end.\n");
+    llama_batch_free(batch);
+}
+
 struct backend_test_case {
     const char * name;
     void (*fn)(const char *);
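For contrast, a hypothetical happy-path counterpart (not part of this commit) would mark only the last token as an output and expect llama_decode to succeed; it reuses the helpers the test above relies on (test_model_context, common_batch_add, llama_sampler_backend_init_dist):

static void test_backend_single_output(const char * model_path) {
    test_model_context test_ctx;

    const int seq_id = 0;
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(88));
    std::vector<llama_sampler_seq_config> configs = {{ seq_id, chain }};

    if (!test_ctx.setup(model_path, configs)) {
        return;
    }

    llama_batch batch = llama_batch_init(512, 0, 1);
    std::vector<llama_token> tokens = { llama_vocab_bos(test_ctx.vocab) };

    for (size_t i = 0; i < tokens.size(); i++) {
        // only the final token of the sequence requests an output
        const bool is_last = (i == tokens.size() - 1);
        common_batch_add(batch, tokens[i], i, { seq_id }, is_last);
    }

    const int ret = llama_decode(test_ctx.ctx, batch);
    GGML_ASSERT(ret == 0 && "llama_decode should succeed with one output per sequence");
    llama_batch_free(batch);
}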
@@ -644,6 +684,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "dist", test_backend_dist_sampling, true },
     { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
     { "set_sampler", test_backend_set_sampler, true },
+    { "max_outputs", test_backend_max_outputs, true },
 };
 
 struct backend_cli_args {