tests : remove vocab member from test_model_context

Also includes some minor cleanups of nullptr checks (x == nullptr comparisons replaced with !x).
Daniel Bevenius 2025-12-17 11:46:36 +01:00
parent 9845996919
commit 76a1b7fe8c
1 changed file with 17 additions and 25 deletions


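The net effect on the struct is that the cached const llama_vocab * member is gone and the vocabulary is derived from the owning llama_model_ptr on demand, which is what makes the cleanup() helper and the destructor unnecessary. Below is a condensed sketch of the resulting shape, assuming the test file's smart-pointer aliases come from llama-cpp.h (only the members touched by this commit are shown; this is an illustration, not the full test file):

    #include "llama-cpp.h" // llama_model_ptr / llama_context_ptr: unique_ptr aliases with llama_*_free deleters

    struct test_model_context {
        llama_model_ptr model;   // owns the model; released automatically when the struct goes away
        llama_context_ptr ctx;   // owns the context
        int n_vocab = 0;

        // Replaces the cached raw vocab pointer: look it up from the model each time,
        // so there is nothing to null out in a cleanup() helper or a destructor.
        const llama_vocab * get_vocab() const {
            return model ? llama_model_get_vocab(model.get()) : nullptr;
        }
    };

Call sites in the individual tests switch accordingly from test_ctx.vocab to test_ctx.get_vocab(), as in the llama_tokenize and llama_vocab_bos calls in the hunks below.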
@@ -18,10 +18,9 @@
 #include <vector>

 struct test_model_context {
     llama_model_ptr model;
     llama_context_ptr ctx;
-    const llama_vocab * vocab = nullptr;
     int n_vocab = 0;

     std::unordered_map<llama_seq_id, int32_t> seq_positions;
     std::unordered_map<llama_seq_id, int32_t> last_batch_info;

@@ -45,18 +44,16 @@ struct test_model_context {
         if (!model) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
-            cleanup();
             return false;
         }

-        vocab = llama_model_get_vocab(model.get());
-        n_vocab = llama_vocab_n_tokens(vocab);
+        n_vocab = llama_vocab_n_tokens(get_vocab());

         fprintf(stderr, "Vocabulary size: %d\n", n_vocab);

         return true;
     }

     bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
-        if (model == nullptr) {
+        if (!model) {
             load_model(model_path);
         }

@@ -82,9 +79,8 @@ struct test_model_context {
         }

         ctx.reset(llama_init_from_model(model.get(), cparams));
-        if (ctx == nullptr) {
+        if (!ctx) {
             fprintf(stderr, "Warning: failed to create context, skipping test\n");
-            cleanup();
             return false;
         }
         llama_set_warmup(ctx.get(), false);

@@ -93,7 +89,7 @@ struct test_model_context {
     }

     bool decode(const std::map<llama_seq_id, std::string> & prompts) {
-        if (ctx == nullptr || vocab == nullptr) {
+        if (!ctx) {
             fprintf(stderr, "Error: context not initialized, call setup() first\n");
             return false;
         }

@@ -103,6 +99,7 @@ struct test_model_context {
         int n_tokens_per_prompt = 0;

+        auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
             tokens.push_back(llama_vocab_bos(vocab));

@@ -246,10 +243,10 @@ struct test_model_context {
     std::string token_to_piece(llama_token token, bool special) {
         std::string piece;
         piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+        const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
         if (n_chars < 0) {
             piece.resize(-n_chars);
-            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+            int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
             GGML_ASSERT(check == -n_chars);
         }
         else {

@@ -260,20 +257,15 @@ struct test_model_context {
     }

     void reset() {
-        if (ctx) {
-            ctx.reset();
-        }
+        ctx.reset();
         seq_positions.clear();
         last_batch_info.clear();
     }

-    void cleanup() {
-        vocab = nullptr;
+    const llama_vocab * get_vocab() const {
+        return model ? llama_model_get_vocab(model.get()) : nullptr;
     }
-
-    ~test_model_context() {
-        cleanup();
-    }
 };

 static void test_backend_greedy_sampling(const char * model_path) {

@@ -795,7 +787,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     // Get the token for the piece "World".
     const std::string piece = "World";
     std::vector<llama_token> tokens(16);
-    llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
+    llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
     llama_token bias_token = tokens[0];
     logit_bias.push_back({ bias_token, +100.0f });
     printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);

@@ -803,7 +795,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
-                llama_vocab_n_tokens(test_ctx.vocab),
+                llama_vocab_n_tokens(test_ctx.get_vocab()),
                 logit_bias.size(),
                 logit_bias.data()));
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));

@@ -1062,10 +1054,10 @@ static void test_backend_max_outputs(const char * model_path) {
     std::string prompt = "Hello";

     std::vector<llama_token> tokens;
-    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+    tokens.push_back(llama_vocab_bos(test_ctx.get_vocab()));

     std::vector<llama_token> prompt_tokens(32);
-    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+    int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(),
                                   prompt_tokens.data(), prompt_tokens.size(),
                                   false, false);
     for (int i = 0; i < n_tokens; i++) {