tests : fix batch token position tracking in test_backend_sampler.cpp

Daniel Bevenius 2025-12-17 13:49:39 +01:00
parent cc31e6a20e
commit a519aea35c
1 changed file with 9 additions and 18 deletions

test_backend_sampler.cpp

@@ -97,8 +97,6 @@ struct test_model_context {
         last_batch_info.clear();
 
         llama_batch batch = llama_batch_init(512, 0, prompts.size());
-        int n_tokens_per_prompt = 0;
-
         auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
@@ -108,18 +106,6 @@ struct test_model_context {
             int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                           prompt_tokens.data(), prompt_tokens.size(),
                                           false, false);
-            //TODO: refactor this function to just handle a single prompt at a time
-            // to avoid this check and complexity.
-            if (n_tokens_per_prompt == 0) {
-                n_tokens_per_prompt = n_tokens;
-            } else {
-                if (n_tokens != n_tokens_per_prompt) {
-                    fprintf(stderr, "Error: prompts must have the same number of tokens\n");
-                    llama_batch_free(batch);
-                    return false;
-                }
-                n_tokens_per_prompt = n_tokens;
-            }
             if (n_tokens < 0) {
                 fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                 llama_batch_free(batch);
@@ -130,11 +116,16 @@ struct test_model_context {
                 tokens.push_back(prompt_tokens[i]);
             }
 
-            for (size_t i = 0; i < tokens.size(); i++) {
-                common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+            if (seq_positions.find(seq_id) == seq_positions.end()) {
+                seq_positions[seq_id] = 0;
             }
-            seq_positions[seq_id] = tokens.size();
+
+            int32_t start_pos = seq_positions[seq_id];
+            for (size_t i = 0; i < tokens.size(); i++) {
+                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
+            }
+            seq_positions[seq_id] = start_pos + tokens.size();
         }
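
Note: the change above makes positions cumulative per sequence. Previously every decode call placed tokens at positions 0..n-1 and overwrote seq_positions[seq_id] with just the latest prompt length; now each call starts from the recorded position and advances it. A minimal standalone sketch of that pattern, for illustration only (toy_token, append_tokens, and positions are stand-ins, not llama.cpp API):

    // Per-sequence position tracking: each append continues where the
    // previous one left off instead of restarting at position 0.
    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <vector>

    struct toy_token { int32_t id; int32_t pos; };

    // seq_id -> next absolute position, mirroring seq_positions in the test.
    static std::map<int32_t, int32_t> positions;

    static void append_tokens(std::vector<toy_token> & batch, int32_t seq_id,
                              const std::vector<int32_t> & tokens) {
        if (positions.find(seq_id) == positions.end()) {
            positions[seq_id] = 0;
        }
        const int32_t start_pos = positions[seq_id];
        for (size_t i = 0; i < tokens.size(); i++) {
            batch.push_back({ tokens[i], start_pos + (int32_t) i });
        }
        positions[seq_id] = start_pos + (int32_t) tokens.size();
    }

    int main() {
        std::vector<toy_token> batch;
        append_tokens(batch, 0, { 11, 22, 33 }); // first call: positions 0, 1, 2
        assert(batch.back().pos == 2);
        append_tokens(batch, 0, { 44 });         // second call continues at 3
        assert(batch.back().pos == 3);
        return 0;
    }

With the old code, the second call would have placed token 44 at position 0 and reset the recorded position to 1.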
@@ -375,7 +366,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         return;
     }
 
-    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+    if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
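
Note: the lengthened prompt presumably tokenizes to a different count than "Once upon a", which the removed same-length check would previously have rejected. With cumulative positions, a hypothetical follow-up call (illustrative only, not part of this commit) would continue each sequence where its prompt ended rather than restarting at position 0:

    // Hypothetical second decode for the same sequences: with the fix,
    // seq 0 continues after "Some where over the" and seq 1 after
    // "Once upon a" instead of both restarting at position 0.
    if (!test_ctx.decode({{0, " rainbow"}, {1, " time"}})) {
        GGML_ASSERT(false && "Failed to decode follow-up tokens");
    }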