tests : fix batch token position tracking in test_backend_sampler.cpp
parent cc31e6a20e
commit a519aea35c
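For context, below is a minimal standalone sketch of the position-tracking scheme this change introduces. It does not use llama.cpp itself: `add_token` is a stand-in for `common_batch_add`, and the token values are made up. The point it illustrates is the one fixed in the diff: each sequence keeps a running position, a new batch for that sequence starts at the stored position, and the position is advanced by the number of tokens added, so a follow-up batch continues after the previous one instead of restarting at position 0.

// Minimal sketch (assumed simplification, not the actual test harness):
// per-sequence positions continue across batches.
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

struct pending_token {
    int32_t token;
    int32_t pos;
    int32_t seq_id;
};

static std::map<int32_t, int32_t> seq_positions; // seq_id -> next position to use

// Stand-in for common_batch_add(): record one token at an absolute position.
static void add_token(std::vector<pending_token> & batch, int32_t token, int32_t pos, int32_t seq_id) {
    batch.push_back({token, pos, seq_id});
}

// Append tokens for seq_id, continuing from the sequence's stored position.
static void add_prompt(std::vector<pending_token> & batch, int32_t seq_id, const std::vector<int32_t> & tokens) {
    if (seq_positions.find(seq_id) == seq_positions.end()) {
        seq_positions[seq_id] = 0;
    }
    const int32_t start_pos = seq_positions[seq_id];
    for (size_t i = 0; i < tokens.size(); i++) {
        add_token(batch, tokens[i], start_pos + (int32_t) i, seq_id);
    }
    seq_positions[seq_id] = start_pos + (int32_t) tokens.size();
}

int main() {
    std::vector<pending_token> batch;
    add_prompt(batch, 0, {11, 12, 13}); // first batch: positions 0..2
    add_prompt(batch, 0, {14, 15});     // second batch: positions 3..4, not 0..1
    for (const auto & t : batch) {
        printf("seq %d token %d pos %d\n", t.seq_id, t.token, t.pos);
    }
    return 0;
}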
@@ -97,8 +97,6 @@ struct test_model_context {
         last_batch_info.clear();
         llama_batch batch = llama_batch_init(512, 0, prompts.size());
 
-        int n_tokens_per_prompt = 0;
-
         auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
@@ -108,18 +106,6 @@ struct test_model_context {
             int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                           prompt_tokens.data(), prompt_tokens.size(),
                                           false, false);
-            //TODO: refactor this function to just handle a single prompt at a time
-            // to avoid this check and complexity.
-            if (n_tokens_per_prompt == 0) {
-                n_tokens_per_prompt = n_tokens;
-            } else {
-                if (n_tokens != n_tokens_per_prompt) {
-                    fprintf(stderr, "Error: prompts must have the same number of tokens\n");
-                    llama_batch_free(batch);
-                    return false;
-                }
-                n_tokens_per_prompt = n_tokens;
-            }
             if (n_tokens < 0) {
                 fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                 llama_batch_free(batch);
@@ -130,11 +116,16 @@ struct test_model_context {
                 tokens.push_back(prompt_tokens[i]);
             }
 
-            for (size_t i = 0; i < tokens.size(); i++) {
-                common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
             if (seq_positions.find(seq_id) == seq_positions.end()) {
                 seq_positions[seq_id] = 0;
             }
 
-            seq_positions[seq_id] = tokens.size();
+            int32_t start_pos = seq_positions[seq_id];
+            for (size_t i = 0; i < tokens.size(); i++) {
+                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
+            }
+
+            seq_positions[seq_id] = start_pos + tokens.size();
         }
 
@@ -375,7 +366,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         return;
     }
 
-    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+    if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
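The practical effect is easiest to see with two consecutive decode calls for the same sequence IDs. The first call below is the one from the hunk above; the second is a hypothetical follow-up (the continuation strings are invented for illustration, not taken from the test) which, with this fix, is assigned positions starting after the first prompt rather than clashing with it at position 0.

    // First batch: seq 0 and seq 1 both start at position 0.
    if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Hypothetical second batch for the same sequences: tokens are now placed
    // from seq_positions[seq_id] onward instead of starting again at 0.
    if (!test_ctx.decode({{0, " rainbow"}, {1, " time"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }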