sampling : ensure at most one output token per seq
This commit adds a check in the batch allocator to ensure that when backend sampling is enabled, at most one output token is specified per sequence.
parent 82957a90f2
commit 311c1a347f
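For illustration, a minimal caller-side sketch of what the new check expects: with a backend sampler attached to a sequence, only one token of that sequence may be flagged as an output (typically the last prompt token). This is a hedged sketch, not part of the commit; the batch helpers used (llama_batch_init, common_batch_add, llama_decode) are the standard llama.cpp ones, while the context `ctx` is assumed to have been created with a backend sampler configured for sequence 0, as in the test added below.

    // Hedged sketch: decode a prompt while flagging only the final token as
    // output, which is what the backend-sampling path now requires per sequence.
    #include "llama.h"
    #include "common.h"

    #include <vector>

    static int decode_prompt_single_output(llama_context * ctx, const std::vector<llama_token> & prompt) {
        llama_batch batch = llama_batch_init(512, 0, 1);

        for (size_t i = 0; i < prompt.size(); i++) {
            // only the last token requests an output; marking more than one
            // token for sequence 0 would now make llama_decode fail up front
            const bool is_output = (i == prompt.size() - 1);
            common_batch_add(batch, prompt[i], (llama_pos) i, { 0 }, is_output);
        }

        const int ret = llama_decode(ctx, batch);
        llama_batch_free(batch);
        return ret;
    }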
@@ -28,7 +28,8 @@ bool llama_batch_allocr::init(
         const llama_memory_i * memory,
         uint32_t n_embd,
         uint32_t n_seq_max,
-        bool output_all) {
+        bool output_all,
+        bool backend_sampling) {
     clear();

     batch = batch_inp;
@@ -145,6 +146,24 @@ bool llama_batch_allocr::init(
         }
     }

+    if (backend_sampling) {
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                continue;
+            }
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (%d)\n", __func__, seq_id);
+                    return false;
+                }
+            }
+        }
+    }
+
     //
     // compute stats
     //
@@ -79,7 +79,8 @@ public:
             const llama_memory_i * memory,
             uint32_t n_embd,
             uint32_t n_seq_max,
-            bool output_all);
+            bool output_all,
+            bool backend_sampling = false);

     const llama_batch & get_batch() const;

@@ -1155,7 +1155,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;

-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
+                      cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
+                      output_all,
+                      !samplers.empty())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -629,6 +629,46 @@ static void test_backend_set_sampler(const char * model_path) {
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 }

+static void test_backend_max_outputs(const char * model_path) {
+    test_model_context test_ctx;
+
+    const int seq_id = 0;
+    const int32_t seed = 88;
+    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+    std::string prompt = "Hello";
+
+    std::vector<llama_token> tokens;
+    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+
+    std::vector<llama_token> prompt_tokens(32);
+    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+                                  prompt_tokens.data(), prompt_tokens.size(),
+                                  false, false);
+    for (int i = 0; i < n_tokens; i++) {
+        tokens.push_back(prompt_tokens[i]);
+    }
+
+    for (size_t i = 0; i < tokens.size(); i++) {
+        // set all tokens as output to trigger the error
+        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+    }
+
+    printf(">>> test_max_outputs expected error start:\n");
+    const int ret = llama_decode(test_ctx.ctx, batch);
+    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
+    printf("<<< test_max_outputs expected error end.\n");
+    llama_batch_free(batch);
+}
+
 struct backend_test_case {
     const char * name;
     void (*fn)(const char *);
@@ -644,6 +684,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "dist", test_backend_dist_sampling, true },
     { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
     { "set_sampler", test_backend_set_sampler, true },
+    { "max_outputs", test_backend_max_outputs, true },
 };

 struct backend_cli_args {