From a519aea35c8265b90d12f43cd148b3ed9060fa3f Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Wed, 17 Dec 2025 13:49:39 +0100
Subject: [PATCH] tests : fix batch token position tracking in test-backend-sampler.cpp

Track the next token position for each sequence in seq_positions and
continue from it on subsequent decode() calls, instead of always adding
prompt tokens at positions starting from 0. This also removes the
restriction that all prompts in a batch must have the same number of
tokens.
---
 tests/test-backend-sampler.cpp | 27 +++++++++------------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index d1b4287cdc..65a9e718e7 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -97,8 +97,6 @@ struct test_model_context {
         last_batch_info.clear();
 
         llama_batch batch = llama_batch_init(512, 0, prompts.size());
-        int n_tokens_per_prompt = 0;
-
         auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
@@ -108,18 +106,6 @@ struct test_model_context {
             int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                           prompt_tokens.data(), prompt_tokens.size(), false, false);
 
-            //TODO: refactor this function to just handle a single prompt at a time
-            // to avoid this check and complexity.
-            if (n_tokens_per_prompt == 0) {
-                n_tokens_per_prompt = n_tokens;
-            } else {
-                if (n_tokens != n_tokens_per_prompt) {
-                    fprintf(stderr, "Error: prompts must have the same number of tokens\n");
-                    llama_batch_free(batch);
-                    return false;
-                }
-                n_tokens_per_prompt = n_tokens;
-            }
             if (n_tokens < 0) {
                 fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                 llama_batch_free(batch);
@@ -130,11 +116,16 @@ struct test_model_context {
                 tokens.push_back(prompt_tokens[i]);
             }
 
-            for (size_t i = 0; i < tokens.size(); i++) {
-                common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+            if (seq_positions.find(seq_id) == seq_positions.end()) {
+                seq_positions[seq_id] = 0;
             }
 
-            seq_positions[seq_id] = tokens.size();
+            int32_t start_pos = seq_positions[seq_id];
+            for (size_t i = 0; i < tokens.size(); i++) {
+                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
+            }
+
+            seq_positions[seq_id] = start_pos + tokens.size();
         }
 
@@ -375,7 +366,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         return;
     }
 
-    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+    if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }
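
Note (illustrative sketch, not part of the patch): the standalone program below mirrors the per-sequence
position bookkeeping the patch introduces, using a plain std::map and a hypothetical add_token() stand-in
for common_batch_add(); it does not use any llama.cpp API.

    #include <cstdio>
    #include <map>
    #include <vector>

    // Hypothetical stand-in for common_batch_add(batch, token, pos, { seq_id }, logits).
    static void add_token(int token, int pos, int seq_id, bool logits) {
        std::printf("seq %d: token %d at pos %d%s\n", seq_id, token, pos, logits ? " (logits)" : "");
    }

    int main() {
        // Next position per sequence, mirroring test_model_context::seq_positions.
        std::map<int, int> seq_positions;

        // Two decode() calls for the same sequence; the second must continue from
        // the position where the first left off instead of restarting at 0.
        const std::vector<std::vector<int>> calls = { { 101, 102, 103 }, { 104, 105 } };
        const int seq_id = 0;

        for (const auto & tokens : calls) {
            if (seq_positions.find(seq_id) == seq_positions.end()) {
                seq_positions[seq_id] = 0;
            }
            const int start_pos = seq_positions[seq_id];
            for (size_t i = 0; i < tokens.size(); i++) {
                add_token(tokens[i], start_pos + (int) i, seq_id, i == tokens.size() - 1);
            }
            seq_positions[seq_id] = start_pos + (int) tokens.size();
        }
        return 0;
    }

With the old behaviour the second call would add its tokens at positions 0 and 1 again; with the tracked
start position it adds them at positions 3 and 4.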