tests : cleanup test-backend-sampler.cpp

commit 50d21aa4a4
parent 9e273f7aa4
@@ -16,38 +16,40 @@
 #include <vector>

 struct test_model_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     const llama_vocab * vocab = nullptr;
     int n_vocab = 0;

     std::unordered_map<llama_seq_id, int32_t> seq_positions;
     std::unordered_map<llama_seq_id, int32_t> last_batch_info;

-    bool setup_model(const char * model_path) {
+    bool load_model(const char * model_path) {
         if (model != nullptr) {
             return true;
         }

         llama_backend_init();

-        llama_model_params mparams = llama_model_default_params();
-        model = llama_model_load_from_file(model_path, mparams);
+        model = llama_model_load_from_file(model_path, llama_model_default_params());
         if (model == nullptr) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
             return false;
         }
         vocab = llama_model_get_vocab(model);
+        n_vocab = llama_vocab_n_tokens(vocab);
+        fprintf(stderr, "Vocabulary size: %d\n", n_vocab);

         return true;
     }

     bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs) {
         if (model == nullptr) {
-            setup_model(model_path);
+            load_model(model_path);
         }

-        if (model != nullptr && ctx != nullptr) {
+        if (ctx != nullptr) {
             return true;
         }

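Note: with this change, load_model() owns the whole model/vocab bootstrap and setup() only deals with context creation. For reference, a minimal standalone sketch of that bootstrap using the public llama.cpp C API (not part of this commit; error handling trimmed, model path taken from argv):

    // Load a model, report its vocab size, then tear everything down.
    #include "llama.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
            return 1;
        }

        llama_backend_init();

        llama_model * model = llama_model_load_from_file(argv[1], llama_model_default_params());
        if (model == nullptr) {
            fprintf(stderr, "failed to load model '%s'\n", argv[1]);
            llama_backend_free();
            return 1;
        }

        const llama_vocab * vocab = llama_model_get_vocab(model);
        fprintf(stderr, "Vocabulary size: %d\n", llama_vocab_n_tokens(vocab));

        llama_model_free(model);
        llama_backend_free();
        return 0;
    }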
@@ -73,10 +75,6 @@ struct test_model_context {
         }
         llama_set_warmup(ctx, false);

-        vocab = llama_model_get_vocab(model);
-        n_vocab = llama_vocab_n_tokens(vocab);
-        fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
-
         return true;
     }

@@ -130,15 +128,15 @@ struct test_model_context {


         printf("Batch contents:\n");
-        printf(" n_tokens: %d\n", batch.n_tokens);
+        printf("n_tokens: %d\n", batch.n_tokens);
         for (int i = 0; i < batch.n_tokens; i++) {
-            printf(" token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
+            printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);

             for (int j = 0; j < batch.n_seq_id[i]; j++) {
                 printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
             }
             printf("], logits=%d\n", batch.logits[i]);
         }

         if (llama_decode(ctx, batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed\n");
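Note: the debug output above dumps the raw llama_batch fields. A short sketch of how such a batch is populated by hand for a single sequence (token ids are placeholders, not output of a real tokenizer; requesting logits only for the last position is what makes that position show up in last_batch_info):

    llama_batch batch = llama_batch_init(/*n_tokens_alloc*/ 8, /*embd*/ 0, /*n_seq_max*/ 1);

    const llama_token toks[3] = { 100, 200, 300 };   // placeholder token ids
    for (int i = 0; i < 3; i++) {
        batch.token   [i]    = toks[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;          // everything belongs to sequence 0
        batch.logits  [i]    = (i == 2);   // request output only for the last token
    }
    batch.n_tokens = 3;

    // ... pass to llama_decode(ctx, batch), then read the outputs ...

    llama_batch_free(batch);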
@@ -151,7 +149,6 @@ struct test_model_context {
             if (batch.logits[i]) {
                 llama_seq_id seq_id = batch.seq_id[i][0];
                 last_batch_info[seq_id] = i;
-                printf("seq %d : batch idx %d\n", seq_id, i);
             }
         }

@@ -249,10 +246,15 @@ struct test_model_context {
     }

     void cleanup() {
-        if (ctx) llama_free(ctx);
-        if (model) llama_model_free(model);
+        if (ctx) {
+            llama_free(ctx);
+        }
+        if (model) {
+            llama_model_free(model);
+        }
         llama_backend_free();
-        ctx = nullptr;
+
+        ctx = nullptr;
         model = nullptr;
         vocab = nullptr;
     }
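Note: the expanded cleanup() keeps the usual teardown order of the C API: free the context first, then the model, then call llama_backend_free(). As an aside (hypothetical, not what this test does), the same order can be enforced with a destructor so that early returns cannot leak resources:

    struct llama_test_guard {
        llama_model *   model = nullptr;
        llama_context * ctx   = nullptr;

        ~llama_test_guard() {
            if (ctx)   { llama_free(ctx); }
            if (model) { llama_model_free(model); }
            llama_backend_free();
        }
    };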
@@ -374,36 +376,44 @@ static void test_backend_temp_sampling(const char * model_path) {
         return;
     }

-    int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
-    int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
+    // Verfify sequence 0
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == test_ctx.n_vocab);

-    int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx_0);
-    GGML_ASSERT(n_logits == test_ctx.n_vocab);
+        // Sample from sequence 0 using CPU sampler
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

-    // Sample from sequence 0 using CPU sampler
-    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-    struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(chain_0, llama_sampler_init_dist(18));
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

-    llama_token token_0 = llama_sampler_sample(chain_0, test_ctx.ctx, batch_idx_0);
-    const std::string token_0_str = test_ctx.token_to_piece(token_0, false);
-    printf("Sequence 0 sampled token id:%d, string: '%s'\n", token_0, token_0_str.c_str());
-    GGML_ASSERT(token_0 >= 0 && token_0 < test_ctx.n_vocab);
+        llama_sampler_free(chain);
+    }

-    // Sample from sequence 1 using CPU sampler
-    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
-    struct llama_sampler * chain_1 = llama_sampler_chain_init(chain_params_1);
-    llama_sampler_chain_add(chain_1, llama_sampler_init_dist(18));
+    // Verfify sequence 1
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);

-    llama_token token_1 = llama_sampler_sample(chain_1, test_ctx.ctx, batch_idx_1);
-    const std::string token_1_str = test_ctx.token_to_piece(token_1, false);
-    printf("Sequence 1 sampled token id:%d, string: '%s'\n", token_1, token_1_str.c_str());
-    GGML_ASSERT(token_1 >= 0 && token_1 < test_ctx.n_vocab);
+        // Sample from sequence 1 using CPU sampler
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+        llama_sampler_free(chain);
+    }

     printf("backend temp sampling test PASSED\n");
-
-    llama_sampler_free(chain_0);
-    llama_sampler_free(chain_1);
 }

 static void test_backend_multi_sequence_sampling(const char * model_path) {
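Note: both blocks above use the same CPU-side sampling path. A condensed sketch of that pattern, assuming ctx and idx come from an earlier successful llama_decode() (the fixed seed 18 matches the test; the temperature stage is illustrative only):

    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));   // optional temperature stage
    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));     // seeded final draw

    const llama_token tok = llama_sampler_sample(chain, ctx, idx);   // samples from the logits at output index idx
    printf("sampled token id: %d\n", tok);

    llama_sampler_free(chain);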
@@ -436,17 +446,23 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
         return;
     }

-    int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
-    llama_token seq0_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_0);
-    const std::string seq0_token_str = test_ctx.token_to_piece(seq0_token, false);
-    printf("Seq 0 sampled token id=%d, string='%s'\n", seq0_token, seq0_token_str.c_str());
-    GGML_ASSERT(seq0_token >= 0 && seq0_token < test_ctx.n_vocab);
+    // Verfiy sequence 0
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+    }

-    int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
-    llama_token seq1_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_1);
-    const std::string seq1_token_str = test_ctx.token_to_piece(seq1_token, false);
-    printf("Seq 1 sampled token id=%d, string='%s'\n", seq1_token, seq1_token_str.c_str());
-    GGML_ASSERT(seq1_token >= 0 && seq1_token < test_ctx.n_vocab);
+    // Verify sequence 1
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+    }

     // Generate tokens for each sequence
     printf("\nMulti-sequence generation:\n");
@@ -473,26 +489,29 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
 static void test_backend_dist_sampling(const char * model_path) {
     test_model_context test_ctx;

+    const int seq_id = 189;
     const int32_t seed = 88;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
     llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
-    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ 0, backend_sampler_chain }};
+    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
     }

-    if (!test_ctx.decode({{0, "Hello"}})) {
+    if (!test_ctx.decode({{seq_id, "Some"}})) {
         return;
     }

-    llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(0));
-    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
+    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+    llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+    GGML_ASSERT(llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);

     token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1);
-    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
+    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 }

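Note: llama_sampler_backend_init_dist, llama_get_backend_sampled_token_ith and llama_get_backend_sampled_logits_ith come from this branch, not upstream llama.cpp; the usage below is inferred purely from the test code in this hunk. The new assert appears to check that once a backend sampler has picked a token for a sequence, full logits for that output index are deliberately not exposed:

    // idx is the batch output index of the sequence, e.g. test_ctx.idx_for_seq(seq_id)
    llama_token tok = llama_get_backend_sampled_token_ith(ctx, idx);
    GGML_ASSERT(tok >= 0 && tok < n_vocab);
    GGML_ASSERT(llama_get_backend_sampled_logits_ith(ctx, idx) == nullptr);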
@@ -510,7 +529,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
         return;
     }

-    if (!test_ctx.decode({{seq_id, "Hello"}})) {
+    if (!test_ctx.decode({{seq_id, "Some"}})) {
         return;
     }

@@ -523,14 +542,15 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {

     llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
     llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
     GGML_ASSERT(backend_token == cpu_token);
 }

 static void test_backend_logit_bias_sampling(const char * model_path) {
     test_model_context test_ctx;

-    // Calling setup_model to ensure vocab is loaded and can be accessed
-    if (!test_ctx.setup_model(model_path)) {
+    // Calling load_model to ensure vocab is loaded and can be accessed
+    if (!test_ctx.load_model(model_path)) {
         return;
     }

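Note: the logit-bias test calls load_model() first because building a bias list requires the vocab (valid token ids and n_vocab). Whether it uses the upstream CPU sampler or a backend variant is not visible in this hunk; for comparison, the standard CPU-side construction looks like this (token ids are placeholders):

    #include <cmath>   // for INFINITY

    const llama_logit_bias biases[] = {
        { /*token*/ 42, /*bias*/ +10.0f    },   // strongly favor token 42
        { /*token*/ 43, /*bias*/ -INFINITY },   // forbid token 43
    };
    struct llama_sampler * smpl = llama_sampler_init_logit_bias(n_vocab, /*n_logit_bias*/ 2, biases);
    // add to a sampler chain with llama_sampler_chain_add(), free with llama_sampler_free()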
@@ -616,7 +636,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
         GGML_ASSERT(llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
     }

-    // Verfiy sequence 0 that used the top-k backend sampler.
+    // Verfiy sequence 1 that used the top-k backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
         float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);