tests : cleanup test-backend-sampler.cpp

Daniel Bevenius 2025-11-24 07:18:39 +01:00
parent 9e273f7aa4
commit 50d21aa4a4
1 changed file with 85 additions and 65 deletions


@@ -16,38 +16,40 @@
#include <vector>
struct test_model_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
const llama_vocab * vocab = nullptr;
int n_vocab = 0;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
const llama_vocab * vocab = nullptr;
int n_vocab = 0;
std::unordered_map<llama_seq_id, int32_t> seq_positions;
std::unordered_map<llama_seq_id, int32_t> last_batch_info;
bool setup_model(const char * model_path) {
bool load_model(const char * model_path) {
if (model != nullptr) {
return true;
}
llama_backend_init();
llama_model_params mparams = llama_model_default_params();
model = llama_model_load_from_file(model_path, mparams);
model = llama_model_load_from_file(model_path, llama_model_default_params());
if (model == nullptr) {
fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
cleanup();
return false;
}
vocab = llama_model_get_vocab(model);
vocab = llama_model_get_vocab(model);
n_vocab = llama_vocab_n_tokens(vocab);
fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
return true;
}
bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs) {
if (model == nullptr) {
setup_model(model_path);
load_model(model_path);
}
if (model != nullptr && ctx != nullptr) {
if (ctx != nullptr) {
return true;
}
@@ -73,10 +75,6 @@ struct test_model_context {
}
llama_set_warmup(ctx, false);
vocab = llama_model_get_vocab(model);
n_vocab = llama_vocab_n_tokens(vocab);
fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
return true;
}
@@ -130,15 +128,15 @@ struct test_model_context {
printf("Batch contents:\n");
printf(" n_tokens: %d\n", batch.n_tokens);
printf("n_tokens: %d\n", batch.n_tokens);
for (int i = 0; i < batch.n_tokens; i++) {
printf(" token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
for (int j = 0; j < batch.n_seq_id[i]; j++) {
printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
for (int j = 0; j < batch.n_seq_id[i]; j++) {
printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
}
printf("], logits=%d\n", batch.logits[i]);
}
printf("], logits=%d\n", batch.logits[i]);
}
if (llama_decode(ctx, batch) != 0) {
fprintf(stderr, "Warning: llama_decode failed\n");
@@ -151,7 +149,6 @@ struct test_model_context {
if (batch.logits[i]) {
llama_seq_id seq_id = batch.seq_id[i][0];
last_batch_info[seq_id] = i;
printf("seq %d : batch idx %d\n", seq_id, i);
}
}
@@ -249,10 +246,15 @@ struct test_model_context {
}
void cleanup() {
if (ctx) llama_free(ctx);
if (model) llama_model_free(model);
if (ctx) {
llama_free(ctx);
}
if (model) {
llama_model_free(model);
}
llama_backend_free();
ctx = nullptr;
ctx = nullptr;
model = nullptr;
vocab = nullptr;
}
@@ -374,36 +376,44 @@ static void test_backend_temp_sampling(const char * model_path) {
return;
}
int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
// Verify sequence 0
{
int32_t batch_idx = test_ctx.idx_for_seq(0);
int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == test_ctx.n_vocab);
int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx_0);
GGML_ASSERT(n_logits == test_ctx.n_vocab);
// Sample from sequence 0 using CPU sampler
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
// Sample from sequence 0 using CPU sampler
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0);
llama_sampler_chain_add(chain_0, llama_sampler_init_dist(18));
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
const std::string token_str = test_ctx.token_to_piece(token, false);
printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
llama_token token_0 = llama_sampler_sample(chain_0, test_ctx.ctx, batch_idx_0);
const std::string token_0_str = test_ctx.token_to_piece(token_0, false);
printf("Sequence 0 sampled token id:%d, string: '%s'\n", token_0, token_0_str.c_str());
GGML_ASSERT(token_0 >= 0 && token_0 < test_ctx.n_vocab);
llama_sampler_free(chain);
}
// Sample from sequence 1 using CPU sampler
struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
struct llama_sampler * chain_1 = llama_sampler_chain_init(chain_params_1);
llama_sampler_chain_add(chain_1, llama_sampler_init_dist(18));
// Verify sequence 1
{
int32_t batch_idx = test_ctx.idx_for_seq(1);
llama_token token_1 = llama_sampler_sample(chain_1, test_ctx.ctx, batch_idx_1);
const std::string token_1_str = test_ctx.token_to_piece(token_1, false);
printf("Sequence 1 sampled token id:%d, string: '%s'\n", token_1, token_1_str.c_str());
GGML_ASSERT(token_1 >= 0 && token_1 < test_ctx.n_vocab);
// Sample from sequence 1 using CPU sampler
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
const std::string token_str = test_ctx.token_to_piece(token, false);
printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
llama_sampler_free(chain);
}
printf("backend temp sampling test PASSED\n");
llama_sampler_free(chain_0);
llama_sampler_free(chain_1);
}
static void test_backend_multi_sequence_sampling(const char * model_path) {
@@ -436,17 +446,23 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
return;
}
int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
llama_token seq0_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_0);
const std::string seq0_token_str = test_ctx.token_to_piece(seq0_token, false);
printf("Seq 0 sampled token id=%d, string='%s'\n", seq0_token, seq0_token_str.c_str());
GGML_ASSERT(seq0_token >= 0 && seq0_token < test_ctx.n_vocab);
// Verify sequence 0
{
int32_t batch_idx = test_ctx.idx_for_seq(0);
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
const std::string token_str = test_ctx.token_to_piece(token, false);
printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
}
int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
llama_token seq1_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_1);
const std::string seq1_token_str = test_ctx.token_to_piece(seq1_token, false);
printf("Seq 1 sampled token id=%d, string='%s'\n", seq1_token, seq1_token_str.c_str());
GGML_ASSERT(seq1_token >= 0 && seq1_token < test_ctx.n_vocab);
// Verify sequence 1
{
int32_t batch_idx = test_ctx.idx_for_seq(1);
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
const std::string token_str = test_ctx.token_to_piece(token, false);
printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
}
// Generate tokens for each sequence
printf("\nMulti-sequence generation:\n");
@@ -473,26 +489,29 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
static void test_backend_dist_sampling(const char * model_path) {
test_model_context test_ctx;
const int seq_id = 189;
const int32_t seed = 88;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ 0, backend_sampler_chain }};
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
return;
}
if (!test_ctx.decode({{0, "Hello"}})) {
if (!test_ctx.decode({{seq_id, "Some"}})) {
return;
}
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(0));
printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
GGML_ASSERT(llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1);
printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
}
@@ -510,7 +529,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
return;
}
if (!test_ctx.decode({{seq_id, "Hello"}})) {
if (!test_ctx.decode({{seq_id, "Some"}})) {
return;
}
@@ -523,14 +542,15 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
GGML_ASSERT(backend_token == cpu_token);
}
static void test_backend_logit_bias_sampling(const char * model_path) {
test_model_context test_ctx;
// Calling setup_model to ensure vocab is loaded and can be accessed
if (!test_ctx.setup_model(model_path)) {
// Calling load_model to ensure vocab is loaded and can be accessed
if (!test_ctx.load_model(model_path)) {
return;
}
@@ -616,7 +636,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
GGML_ASSERT(llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
}
// Verify sequence 0 that used the top-k backend sampler.
// Verify sequence 1 that used the top-k backend sampler.
{
int32_t batch_idx = test_ctx.idx_for_seq(1);
float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);