tests : use smart pointers for model and context

commit 9845996919
parent 9a9ea2f6b1
Author: Daniel Bevenius
Date:   2025-12-17 11:26:05 +01:00
1 changed file with 66 additions and 75 deletions

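Note: `llama_model_ptr` and `llama_context_ptr` are the RAII aliases from `include/llama-cpp.h`: `std::unique_ptr` typedefs whose custom deleters call the matching C API free functions. A minimal sketch of their definitions (paraphrased; the header is authoritative):

    #include <memory>
    #include "llama.h"

    // Deleters forward to the C API, so the unique_ptrs free the
    // underlying objects automatically when they go out of scope.
    struct llama_model_deleter {
        void operator()(llama_model * model) { llama_model_free(model); }
    };

    struct llama_context_deleter {
        void operator()(llama_context * context) { llama_free(context); }
    };

    typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
    typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;

Since the C API still takes raw pointers, the call sites below pass `ptr.get()`, and `ptr.reset(...)` replaces the manual free-and-null pattern.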

@@ -18,8 +18,8 @@
 #include <vector>
 
 struct test_model_context {
-    llama_model * model = nullptr;
-    llama_context * ctx = nullptr;
+    llama_model_ptr model;
+    llama_context_ptr ctx;
     const llama_vocab * vocab = nullptr;
     int n_vocab = 0;
@@ -27,7 +27,7 @@ struct test_model_context {
     std::unordered_map<llama_seq_id, int32_t> last_batch_info;
 
     bool load_model(const char * model_path) {
-        if (model != nullptr) {
+        if (model) {
             return true;
         }
@@ -41,13 +41,14 @@ struct test_model_context {
         auto mparams = llama_model_default_params();
         mparams.devices = devs;
 
-        model = llama_model_load_from_file(model_path, mparams);
-        if (model == nullptr) {
+        model.reset(llama_model_load_from_file(model_path, mparams));
+
+        if (!model) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
             return false;
         }
 
-        vocab = llama_model_get_vocab(model);
+        vocab = llama_model_get_vocab(model.get());
         n_vocab = llama_vocab_n_tokens(vocab);
         fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
@@ -59,7 +60,7 @@ struct test_model_context {
             load_model(model_path);
         }
 
-        if (ctx != nullptr) {
+        if (ctx) {
             return true;
         }
@@ -80,13 +81,13 @@ struct test_model_context {
             cparams.n_seq_max = n_seq_max;
         }
 
-        ctx = llama_init_from_model(model, cparams);
+        ctx.reset(llama_init_from_model(model.get(), cparams));
         if (ctx == nullptr) {
             fprintf(stderr, "Warning: failed to create context, skipping test\n");
             cleanup();
             return false;
         }
 
-        llama_set_warmup(ctx, false);
+        llama_set_warmup(ctx.get(), false);
         return true;
     }
@@ -151,7 +152,7 @@ struct test_model_context {
             printf("], logits=%d\n", batch.logits[i]);
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx.get(), batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed\n");
             llama_batch_free(batch);
             return false;
@@ -188,7 +189,7 @@ struct test_model_context {
         int32_t pos = seq_positions[seq_id];
         common_batch_add(batch, token, pos, { seq_id }, true);
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx.get(), batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
             llama_batch_free(batch);
             return false;
@@ -220,7 +221,7 @@ struct test_model_context {
             common_batch_add(batch, token, pos, { seq_id }, true);
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx.get(), batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
             llama_batch_free(batch);
             return false;
@@ -260,23 +261,13 @@ struct test_model_context {
     void reset() {
         if (ctx) {
-            llama_free(ctx);
-            ctx = nullptr;
+            ctx.reset();
         }
         seq_positions.clear();
         last_batch_info.clear();
     }
 
     void cleanup() {
-        if (ctx) {
-            llama_free(ctx);
-        }
-        if (model) {
-            llama_model_free(model);
-        }
-        ctx = nullptr;
-        model = nullptr;
         vocab = nullptr;
     }
@@ -306,17 +297,17 @@ static void test_backend_greedy_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
     printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
+    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
     printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, loop_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx);
         printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         if (!test_ctx.decode_token(token, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -344,14 +335,14 @@ static void test_backend_top_k_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
     for (size_t i = 0; i < n_logits; ++i) {
         printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
     }
 
-    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
+    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx);
     for (size_t i = 0; i < n_candidates; ++i) {
         printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
                test_ctx.token_to_piece(candidates[i], false).c_str());
@@ -364,7 +355,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
     GGML_ASSERT(chain->iface->backend_apply != nullptr);
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -401,7 +392,7 @@ static void test_backend_temp_sampling(const char * model_path) {
     // Verify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == test_ctx.n_vocab);
 
         // Sample from sequence 0 using CPU sampler
@@ -409,7 +400,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -425,7 +416,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -457,7 +448,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == 1);
     };
@@ -495,7 +486,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
     // Verify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == test_ctx.n_vocab);
     }
 }
@@ -527,7 +518,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 
         if (temp <= 0.0f && delta >= 0.0f) {
             GGML_ASSERT(n_logits == 1);
@@ -564,8 +555,8 @@ static void test_backend_min_p_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 
     // Print the logits that are above the min-p threshold
     std::vector<float> filtered_logits;
@@ -582,7 +573,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -590,7 +581,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     // Decode and sample 10 more tokens
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
         printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         if (!test_ctx.decode_token(token, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -620,8 +611,8 @@ static void test_backend_top_p_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 
     // Print the logits that are above the top-p threshold
     std::vector<float> filtered_logits;
@@ -638,7 +629,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -646,7 +637,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
         printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         test_ctx.decode_token(token, 0);
     }
@@ -687,7 +678,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     // Verify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -696,7 +687,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     // Verify sequence 1
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -709,7 +700,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
         for (llama_seq_id seq_id : {0, 1}) {
             int32_t idx = test_ctx.idx_for_seq(seq_id);
-            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, idx);
+            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx);
             const std::string token_str = test_ctx.token_to_piece(token, false);
             printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
             tokens[seq_id] = token;
@@ -743,12 +734,12 @@ static void test_backend_dist_sampling(const char * model_path) {
     }
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
 
-    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
+    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -780,8 +771,8 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
-    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
+    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
 
     GGML_ASSERT(backend_token == cpu_token);
@@ -829,7 +820,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
-    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
     const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
     printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
     GGML_ASSERT(backend_token == bias_token);
@@ -872,22 +863,22 @@ static void test_backend_mixed_sampling(const char * model_path) {
     // Verify sequence 0 that used the dist backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
-        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
+        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
     }
 
     // Verify sequence 1 that used the top-k backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(logits != nullptr);
-        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == (size_t) k);
-        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
+        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL);
     }
 
     printf("backend mixed sampling test PASSED\n");
@@ -914,12 +905,12 @@ static void test_backend_set_sampler(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
     // Sample using backend sampler configured above
-    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
     const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
 
     // Now clear the backend sampler for this sequence.
-    llama_set_sampler(test_ctx.ctx, seq_id, nullptr);
+    llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr);
     printf("Cleared backend sampler for seq_id %d\n", seq_id);
 
     // Sample using CPU sampler
@@ -934,11 +925,11 @@ static void test_backend_set_sampler(const char * model_path) {
     // Should not have any sampled token or probs after clearing the backend sampler.
     const int32_t idx = test_ctx.idx_for_seq(seq_id);
-    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL);
-    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);
+    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL);
+    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr);
 
     // Sample the token using the CPU sampler chain.
-    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx, seq_id);
+    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id);
     const std::string token2_str = test_ctx.token_to_piece(token2, false);
     printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
 
     std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };
@@ -948,13 +939,13 @@ static void test_backend_set_sampler(const char * model_path) {
     llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
     llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
     llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
-    llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain.get());
+    llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get());
 
     if (!test_ctx.decode_tokens(tokens2)) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
-    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
     const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
@@ -990,7 +981,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     // Verify sequence 0 (backend sampled)
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -1000,14 +991,14 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
 
-        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
 
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
         llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -1017,7 +1008,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         // clear the backend sampler for seq 0 so that there are no backend
        // samplers.
-        llama_set_sampler(test_ctx.ctx, 0, nullptr);
+        llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);
 
        // Create a CPU sampler and verify we can sample from it.
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@@ -1025,7 +1016,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         if (!test_ctx.decode_token(token, 1)) {
             GGML_ASSERT(false && "Failed to decode token");
         }
@@ -1037,14 +1028,14 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
-        llama_set_sampler(test_ctx.ctx, 0, sampler_chain.get());
+        llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get());
 
         if (!test_ctx.decode_token(3834, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
         }
 
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -1087,7 +1078,7 @@ static void test_backend_max_outputs(const char * model_path) {
     }
 
     printf(">>> test_max_outputs expected error start:\n");
-    const int ret = llama_decode(test_ctx.ctx, batch);
+    const int ret = llama_decode(test_ctx.ctx.get(), batch);
     GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence");
     printf("<<< test_max_outputs expected error end.\n");
 
     llama_batch_free(batch);