tests : remove vocab member from test_model_context

Also includes some minor cleanups related to nullptr checks: comparisons like model == nullptr become !model, and a redundant if (ctx) guard before ctx.reset() is removed.
Daniel Bevenius 2025-12-17 11:46:36 +01:00
parent 9845996919
commit 76a1b7fe8c
1 changed file with 17 additions and 25 deletions
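
The shape of the refactor, as a minimal sketch (assuming the llama_model_ptr / llama_context_ptr smart-pointer helpers from llama-cpp.h; member and function names are taken from the diff below): instead of caching a const llama_vocab * member that cleanup() has to reset, the test context derives the vocab from the owning model on demand.

    #include "llama-cpp.h" // llama_model_ptr / llama_context_ptr smart pointers

    struct test_model_context {
        llama_model_ptr   model;
        llama_context_ptr ctx;
        int               n_vocab = 0;

        // The vocab is owned by the model, so look it up on demand instead of
        // storing a separate raw pointer; returns nullptr until a model is loaded.
        const llama_vocab * get_vocab() const {
            return model ? llama_model_get_vocab(model.get()) : nullptr;
        }
    };

Call sites then read the vocab through the accessor, e.g. llama_vocab_n_tokens(get_vocab()) or llama_tokenize(test_ctx.get_vocab(), ...), as in the hunks below.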


@@ -18,10 +18,9 @@
#include <vector>
struct test_model_context {
llama_model_ptr model;
llama_context_ptr ctx;
const llama_vocab * vocab = nullptr;
int n_vocab = 0;
llama_model_ptr model;
llama_context_ptr ctx;
int n_vocab = 0;
std::unordered_map<llama_seq_id, int32_t> seq_positions;
std::unordered_map<llama_seq_id, int32_t> last_batch_info;
@@ -45,18 +44,16 @@ struct test_model_context {
if (!model) {
fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
cleanup();
return false;
}
vocab = llama_model_get_vocab(model.get());
n_vocab = llama_vocab_n_tokens(vocab);
n_vocab = llama_vocab_n_tokens(get_vocab());
fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
return true;
}
bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
if (model == nullptr) {
if (!model) {
load_model(model_path);
}
@@ -82,9 +79,8 @@ struct test_model_context {
}
ctx.reset(llama_init_from_model(model.get(), cparams));
if (ctx == nullptr) {
if (!ctx) {
fprintf(stderr, "Warning: failed to create context, skipping test\n");
cleanup();
return false;
}
llama_set_warmup(ctx.get(), false);
@@ -93,7 +89,7 @@ struct test_model_context {
}
bool decode(const std::map<llama_seq_id, std::string> & prompts) {
if (ctx == nullptr || vocab == nullptr) {
if (!ctx) {
fprintf(stderr, "Error: context not initialized, call setup() first\n");
return false;
}
@@ -103,6 +99,7 @@ struct test_model_context {
int n_tokens_per_prompt = 0;
auto vocab = get_vocab();
for (const auto & [seq_id, prompt] : prompts) {
std::vector<llama_token> tokens;
tokens.push_back(llama_vocab_bos(vocab));
@@ -246,10 +243,10 @@ struct test_model_context {
std::string token_to_piece(llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
@@ -260,20 +257,15 @@ struct test_model_context {
}
void reset() {
if (ctx) {
ctx.reset();
}
ctx.reset();
seq_positions.clear();
last_batch_info.clear();
}
void cleanup() {
vocab = nullptr;
const llama_vocab * get_vocab() const {
return model ? llama_model_get_vocab(model.get()) : nullptr;
}
~test_model_context() {
cleanup();
}
};
static void test_backend_greedy_sampling(const char * model_path) {
@@ -795,7 +787,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
// Get the token for the piece "World".
const std::string piece = "World";
std::vector<llama_token> tokens(16);
llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
llama_token bias_token = tokens[0];
logit_bias.push_back({ bias_token, +100.0f });
printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
@@ -803,7 +795,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
llama_vocab_n_tokens(test_ctx.vocab),
llama_vocab_n_tokens(test_ctx.get_vocab()),
logit_bias.size(),
logit_bias.data()));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
@@ -1062,10 +1054,10 @@ static void test_backend_max_outputs(const char * model_path) {
std::string prompt = "Hello";
std::vector<llama_token> tokens;
tokens.push_back(llama_vocab_bos(test_ctx.vocab));
tokens.push_back(llama_vocab_bos(test_ctx.get_vocab()));
std::vector<llama_token> prompt_tokens(32);
int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(),
prompt_tokens.data(), prompt_tokens.size(),
false, false);
for (int i = 0; i < n_tokens; i++) {