tests : remove vocab member from test_model_context
Also includes some minor cleanups related to nullptr checks.
This commit is contained in:
parent 9845996919
commit 76a1b7fe8c
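
The gist of the change: instead of caching a const llama_vocab * member next to the model (and having to null it out in cleanup()), the test context now derives the vocab from the model on demand via get_vocab(). A minimal sketch of the resulting pattern, assuming llama.h and the llama_model_ptr RAII alias from llama-cpp.h; the struct name and load_model body below are illustrative, not the test's exact code:

    #include "llama.h"
    #include "llama-cpp.h"  // provides llama_model_ptr (std::unique_ptr with a llama deleter)

    // Illustrative mirror of test_model_context: the vocab is no longer a cached member.
    struct model_holder {
        llama_model_ptr model;
        int             n_vocab = 0;

        // Derive the vocab from the model each time; returns nullptr while no model is
        // loaded, so there is nothing extra to reset during cleanup.
        const llama_vocab * get_vocab() const {
            return model ? llama_model_get_vocab(model.get()) : nullptr;
        }

        bool load_model(const char * model_path) {
            model.reset(llama_model_load_from_file(model_path, llama_model_default_params()));
            if (!model) {
                return false;
            }
            n_vocab = llama_vocab_n_tokens(get_vocab());
            return true;
        }
    };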
@@ -18,10 +18,9 @@
 #include <vector>
 
 struct test_model_context {
-    llama_model_ptr     model;
-    llama_context_ptr   ctx;
-    const llama_vocab * vocab   = nullptr;
-    int                 n_vocab = 0;
+    llama_model_ptr   model;
+    llama_context_ptr ctx;
+    int               n_vocab = 0;
 
     std::unordered_map<llama_seq_id, int32_t> seq_positions;
     std::unordered_map<llama_seq_id, int32_t> last_batch_info;
@@ -45,18 +44,16 @@ struct test_model_context {
 
         if (!model) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
             return false;
         }
-        vocab = llama_model_get_vocab(model.get());
-        n_vocab = llama_vocab_n_tokens(vocab);
+        n_vocab = llama_vocab_n_tokens(get_vocab());
         fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
 
         return true;
     }
 
     bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
-        if (model == nullptr) {
+        if (!model) {
             load_model(model_path);
         }
 
@@ -82,9 +79,8 @@ struct test_model_context {
         }
 
         ctx.reset(llama_init_from_model(model.get(), cparams));
-        if (ctx == nullptr) {
+        if (!ctx) {
             fprintf(stderr, "Warning: failed to create context, skipping test\n");
             cleanup();
             return false;
         }
         llama_set_warmup(ctx.get(), false);
@@ -93,7 +89,7 @@
     }
 
     bool decode(const std::map<llama_seq_id, std::string> & prompts) {
-        if (ctx == nullptr || vocab == nullptr) {
+        if (!ctx) {
             fprintf(stderr, "Error: context not initialized, call setup() first\n");
             return false;
         }
@@ -103,6 +99,7 @@
 
         int n_tokens_per_prompt = 0;
 
+        auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
             tokens.push_back(llama_vocab_bos(vocab));
@@ -246,10 +243,10 @@
     std::string token_to_piece(llama_token token, bool special) {
         std::string piece;
         piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
-        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+        const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
         if (n_chars < 0) {
             piece.resize(-n_chars);
-            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+            int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
             GGML_ASSERT(check == -n_chars);
         }
         else {
@@ -260,20 +257,15 @@
     }
 
     void reset() {
-        if (ctx) {
-            ctx.reset();
-        }
+        ctx.reset();
         seq_positions.clear();
         last_batch_info.clear();
     }
 
-    void cleanup() {
-        vocab = nullptr;
+    const llama_vocab * get_vocab() const {
+        return model ? llama_model_get_vocab(model.get()) : nullptr;
     }
 
     ~test_model_context() {
         cleanup();
     }
 };
 
 static void test_backend_greedy_sampling(const char * model_path) {
@@ -795,7 +787,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     // Get the token for the piece "World".
     const std::string piece = "World";
     std::vector<llama_token> tokens(16);
-    llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
+    llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
     llama_token bias_token = tokens[0];
     logit_bias.push_back({ bias_token, +100.0f });
     printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
@@ -803,7 +795,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
-        llama_vocab_n_tokens(test_ctx.vocab),
+        llama_vocab_n_tokens(test_ctx.get_vocab()),
         logit_bias.size(),
         logit_bias.data()));
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
@@ -1062,10 +1054,10 @@ static void test_backend_max_outputs(const char * model_path) {
     std::string prompt = "Hello";
 
     std::vector<llama_token> tokens;
-    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+    tokens.push_back(llama_vocab_bos(test_ctx.get_vocab()));
 
     std::vector<llama_token> prompt_tokens(32);
-    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+    int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(),
                                   prompt_tokens.data(), prompt_tokens.size(),
                                   false, false);
     for (int i = 0; i < n_tokens; i++) {
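
For context on the token_to_piece hunk above: llama_token_to_piece returns the negative of the required buffer size when the destination buffer is too small, which is why the helper resizes and retries. A standalone sketch of that two-call pattern; the wrapper name is illustrative, only llama_token_to_piece itself is from llama.h:

    #include <string>
    #include "llama.h"

    // Illustrative helper: convert a token to its text piece, growing the buffer on demand.
    static std::string piece_from_token(const llama_vocab * vocab, llama_token token, bool special) {
        std::string piece;
        piece.resize(piece.capacity()); // small-string buffer: at least 15 bytes before any heap allocation
        int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
        if (n_chars < 0) {
            piece.resize(-n_chars);     // negative return value encodes the required size
            n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
        }
        piece.resize(n_chars);          // trim to the number of characters actually written
        return piece;
    }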