tests : fix batch token position tracking in test_backend_sampler.cpp
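
Before this change, decode() assigned token positions starting from 0 on
every call and overwrote seq_positions[seq_id] with the size of the latest
batch, so a second decode for the same sequence reused positions 0..n-1.
Positions now continue from each sequence's last tracked position, and the
check forcing all prompts in a batch to tokenize to the same length is
dropped.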
parent cc31e6a20e
commit a519aea35c
@@ -97,8 +97,6 @@ struct test_model_context {
         last_batch_info.clear();
 
         llama_batch batch = llama_batch_init(512, 0, prompts.size());
-        int n_tokens_per_prompt = 0;
-
         auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
@@ -108,18 +106,6 @@ struct test_model_context {
             int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                           prompt_tokens.data(), prompt_tokens.size(),
                                           false, false);
-            //TODO: refactor this function to just handle a single prompt at a time
-            // to avoid this check and complexity.
-            if (n_tokens_per_prompt == 0) {
-                n_tokens_per_prompt = n_tokens;
-            } else {
-                if (n_tokens != n_tokens_per_prompt) {
-                    fprintf(stderr, "Error: prompts must have the same number of tokens\n");
-                    llama_batch_free(batch);
-                    return false;
-                }
-                n_tokens_per_prompt = n_tokens;
-            }
             if (n_tokens < 0) {
                 fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                 llama_batch_free(batch);
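
The block removed above required every prompt in a single decode() call to
tokenize to the same number of tokens. With positions tracked per sequence
(next hunk), that restriction and its TODO are no longer needed.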
@@ -130,11 +116,16 @@ struct test_model_context {
                 tokens.push_back(prompt_tokens[i]);
             }
 
-            for (size_t i = 0; i < tokens.size(); i++) {
-                common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
-            }
-
-            seq_positions[seq_id] = tokens.size();
+            if (seq_positions.find(seq_id) == seq_positions.end()) {
+                seq_positions[seq_id] = 0;
+            }
+
+            int32_t start_pos = seq_positions[seq_id];
+            for (size_t i = 0; i < tokens.size(); i++) {
+                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
+            }
+
+            seq_positions[seq_id] = start_pos + tokens.size();
         }
 
 
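
The core of the fix is the pattern above: positions for a sequence start at
seq_positions[seq_id] rather than 0, and the tracker advances by the number
of tokens added. A minimal standalone sketch of the same pattern, with a
hypothetical add_tokens() standing in for decode() and no llama.cpp API:

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

static std::map<int32_t, int32_t> seq_positions;

static void add_tokens(int32_t seq_id, const std::vector<int32_t> & tokens) {
    // First batch for this sequence starts at position 0.
    if (seq_positions.find(seq_id) == seq_positions.end()) {
        seq_positions[seq_id] = 0;
    }
    // Continue from wherever the sequence left off instead of restarting at 0.
    int32_t start_pos = seq_positions[seq_id];
    for (size_t i = 0; i < tokens.size(); i++) {
        printf("seq %d: token %d -> pos %d\n", seq_id, tokens[i], start_pos + (int32_t) i);
    }
    // Advance the tracker so the next batch for this sequence continues here.
    seq_positions[seq_id] = start_pos + (int32_t) tokens.size();
}

int main() {
    add_tokens(0, {11, 22, 33}); // pos 0, 1, 2
    add_tokens(0, {44, 55});     // pos 3, 4 (before the fix: 0, 1 again)
    return 0;
}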
@@ -375,7 +366,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         return;
     }
 
-    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+    if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
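
The prompt change presumably makes the two prompts tokenize to different
lengths ("Some where over the" vs. "Once upon a"), which the removed
equal-length check would have rejected, so the test now also covers mixed
prompt lengths. A hypothetical follow-up call (continuation strings are
illustrative) would extend each sequence at its next position rather than
overwriting positions 0..n-1:

if (!test_ctx.decode({{0, " rainbow"}, {1, " time"}})) {
    GGML_ASSERT(false && "Failed to decode token");
}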