diff --git a/tests/test-model-random.cpp b/tests/test-model-random.cpp
index 51d3bfdf29..3645c00b5f 100644
--- a/tests/test-model-random.cpp
+++ b/tests/test-model-random.cpp
@@ -1077,7 +1076,6 @@ int main(int argc, char ** argv) {
     const int32_t n_shared_len = 13;  // prime number, shared prompt length
     const int32_t n_seq_len = 127;    // prime number
 
-    llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
     // TODO: batch with embeddings
     std::vector<model_variant> model_variants;
@@ -1119,6 +1118,8 @@ int main(int argc, char ** argv) {
         // TODO: avoid re-creating reference outputs
 
         for (int32_t n_seq_max : { 1, 2, 5 }) {
+            llama_batch batch = llama_batch_init(n_batch, 0, n_seq_max);
+
             // TODO(later): context shift testing
 
             for (int32_t n_ctx : { n_seq_len * n_seq_max }) {
@@ -1195,6 +1196,7 @@ int main(int argc, char ** argv) {
                     for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
                         seq_id_group.push_back(seq_id);
                         seq_id_n_past[seq_id] += shared_prompt.size();
+                        seq_ids_in_batch.insert(seq_id);
                     };
 
                     for (size_t i = 0; i < shared_prompt.size(); ++i) {
@@ -1272,12 +1274,12 @@ int main(int argc, char ** argv) {
                     }
                 }
             }
+
+            llama_batch_free(batch);
         }
 
         llama_model_free(model);
     }
 
-    llama_batch_free(batch);
-
     return 0;
 }