diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index 23f9f54137..053fc9ccc4 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -18,10 +18,9 @@
 #include

 struct test_model_context {
-    llama_model_ptr     model;
-    llama_context_ptr   ctx;
-    const llama_vocab * vocab   = nullptr;
-    int                 n_vocab = 0;
+    llama_model_ptr   model;
+    llama_context_ptr ctx;
+    int               n_vocab = 0;

     std::unordered_map seq_positions;
     std::unordered_map last_batch_info;
@@ -45,18 +44,16 @@ struct test_model_context {
         if (!model) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
-            cleanup();
             return false;
         }

-        vocab   = llama_model_get_vocab(model.get());
-        n_vocab = llama_vocab_n_tokens(vocab);
+        n_vocab = llama_vocab_n_tokens(get_vocab());

         fprintf(stderr, "Vocabulary size: %d\n", n_vocab);

         return true;
     }

     bool setup(const char * model_path, std::vector & configs, int32_t n_seq_max = -1) {
-        if (model == nullptr) {
+        if (!model) {
             load_model(model_path);
         }
@@ -82,9 +79,8 @@ struct test_model_context {
         }

         ctx.reset(llama_init_from_model(model.get(), cparams));
-        if (ctx == nullptr) {
+        if (!ctx) {
             fprintf(stderr, "Warning: failed to create context, skipping test\n");
-            cleanup();
             return false;
         }
         llama_set_warmup(ctx.get(), false);
@@ -93,7 +89,7 @@ struct test_model_context {
     bool decode(const std::map<llama_seq_id, std::string> & prompts) {
-        if (ctx == nullptr || vocab == nullptr) {
+        if (!ctx) {
             fprintf(stderr, "Error: context not initialized, call setup() first\n");
             return false;
         }

@@ -103,6 +99,7 @@ struct test_model_context {

         int n_tokens_per_prompt = 0;

+        auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
             tokens.push_back(llama_vocab_bos(vocab));
@@ -246,10 +243,10 @@ struct test_model_context {
     std::string token_to_piece(llama_token token, bool special) {
         std::string piece;
         piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+        const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
         if (n_chars < 0) {
             piece.resize(-n_chars);
-            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+            int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
             GGML_ASSERT(check == -n_chars);
         }
         else {
@@ -260,20 +257,15 @@ struct test_model_context {
     }

     void reset() {
-        if (ctx) {
-            ctx.reset();
-        }
+        ctx.reset();
         seq_positions.clear();
         last_batch_info.clear();
     }

-    void cleanup() {
-        vocab = nullptr;
+    const llama_vocab * get_vocab() const {
+        return model ? llama_model_get_vocab(model.get()) : nullptr;
     }
-
-    ~test_model_context() {
-        cleanup();
-    }
 };

 static void test_backend_greedy_sampling(const char * model_path) {
@@ -795,7 +787,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     // Get the token for the piece "World".
     const std::string piece = "World";
     std::vector<llama_token> tokens(16);
-    llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
+    llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
     llama_token bias_token = tokens[0];
     logit_bias.push_back({ bias_token, +100.0f });
     printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
@@ -803,7 +795,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
-        llama_vocab_n_tokens(test_ctx.vocab),
+        llama_vocab_n_tokens(test_ctx.get_vocab()),
         logit_bias.size(),
         logit_bias.data()));
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
@@ -1062,10 +1054,10 @@ static void test_backend_max_outputs(const char * model_path) {
     std::string prompt = "Hello";

     std::vector<llama_token> tokens;
-    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+    tokens.push_back(llama_vocab_bos(test_ctx.get_vocab()));

     std::vector<llama_token> prompt_tokens(32);
-    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+    int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(),
                                   prompt_tokens.data(), prompt_tokens.size(), false, false);

     for (int i = 0; i < n_tokens; i++) {
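The net effect of the refactor: test_model_context no longer caches a raw `const llama_vocab *` that cleanup() and the destructor had to invalidate by hand; the vocab is re-derived from the owning `model` smart pointer whenever it is needed, so the defaulted destructor suffices. A minimal self-contained sketch of that accessor pattern, using hypothetical Model/Vocab stand-ins rather than the real llama.cpp types:

#include <cstdio>
#include <memory>

// Hypothetical stand-ins for llama_model / llama_vocab so the sketch is
// self-contained; the real test uses llama_model_ptr and llama_model_get_vocab().
struct Vocab { int n_tokens = 32000; };
struct Model { Vocab vocab; };

struct ModelContext {
    std::unique_ptr<Model> model;

    // Re-derive the vocab from the owning pointer on demand instead of
    // caching a raw pointer that must be nulled out on teardown.
    const Vocab * get_vocab() const {
        return model ? &model->vocab : nullptr;
    }
};

int main() {
    ModelContext ctx;
    ctx.model = std::make_unique<Model>();

    if (const Vocab * v = ctx.get_vocab()) {
        std::printf("Vocabulary size: %d\n", v->n_tokens);
    }

    ctx.model.reset(); // once the model is gone, get_vocab() yields nullptr
    std::printf("vocab after reset: %s\n", ctx.get_vocab() ? "valid" : "nullptr");
    return 0;
}

Because no raw pointer outlives `model`, resetting or replacing the model cannot leave a dangling vocab, which is what allows the diff to delete cleanup() and the user-defined destructor outright.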