tests : use smart pointers for model and context
commit 9845996919
parent 9a9ea2f6b1
@@ -18,8 +18,8 @@
 #include <vector>
 
 struct test_model_context {
-    llama_model * model = nullptr;
-    llama_context * ctx = nullptr;
+    llama_model_ptr model;
+    llama_context_ptr ctx;
     const llama_vocab * vocab = nullptr;
     int n_vocab = 0;
 
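Note: the `llama_model_ptr` and `llama_context_ptr` members introduced here are presumably the `std::unique_ptr` aliases from llama.cpp's `llama-cpp.h` header, which attach the matching C free functions as custom deleters. A minimal sketch of what this change assumes (shown for reference, not part of the diff):

    #include <memory>

    struct llama_model_deleter {
        void operator()(llama_model * model) { llama_model_free(model); }
    };

    struct llama_context_deleter {
        void operator()(llama_context * context) { llama_free(context); }
    };

    typedef std::unique_ptr<llama_model, llama_model_deleter>     llama_model_ptr;
    typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;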
@@ -27,7 +27,7 @@ struct test_model_context {
     std::unordered_map<llama_seq_id, int32_t> last_batch_info;
 
     bool load_model(const char * model_path) {
-        if (model != nullptr) {
+        if (model) {
             return true;
         }
 
@@ -41,13 +41,14 @@ struct test_model_context {
         auto mparams = llama_model_default_params();
         mparams.devices = devs;
 
-        model = llama_model_load_from_file(model_path, mparams);
-        if (model == nullptr) {
+        model.reset(llama_model_load_from_file(model_path, mparams));
+
+        if (!model) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
             return false;
         }
-        vocab = llama_model_get_vocab(model);
+        vocab = llama_model_get_vocab(model.get());
         n_vocab = llama_vocab_n_tokens(vocab);
         fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
 
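The hunk above shows the usual way to marry a C-style factory to a smart pointer: `reset()` adopts the raw pointer (destroying any previously held model first, so repeated `load_model()` calls cannot leak), while `.get()` lends the raw pointer to C APIs without giving up ownership. A small sketch under those assumptions:

    llama_model_ptr model;
    model.reset(llama_model_load_from_file(path, mparams));         // adopt ownership
    const llama_vocab * vocab = llama_model_get_vocab(model.get()); // borrow only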
@@ -59,7 +60,7 @@ struct test_model_context {
             load_model(model_path);
         }
 
-        if (ctx != nullptr) {
+        if (ctx) {
             return true;
         }
 
@@ -80,13 +81,13 @@ struct test_model_context {
             cparams.n_seq_max = n_seq_max;
         }
 
-        ctx = llama_init_from_model(model, cparams);
+        ctx.reset(llama_init_from_model(model.get(), cparams));
         if (ctx == nullptr) {
             fprintf(stderr, "Warning: failed to create context, skipping test\n");
             cleanup();
             return false;
         }
-        llama_set_warmup(ctx, false);
+        llama_set_warmup(ctx.get(), false);
 
         return true;
     }
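Note that the unchanged `if (ctx == nullptr)` context line still compiles after the conversion: `std::unique_ptr` defines comparison against `nullptr`, so null checks behave identically for the raw and smart pointer. For reference:

    llama_context_ptr ctx;
    if (ctx == nullptr) { /* same result as if (!ctx) */ }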
@@ -151,7 +152,7 @@ struct test_model_context {
             printf("], logits=%d\n", batch.logits[i]);
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx.get(), batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed\n");
             llama_batch_free(batch);
             return false;
@@ -188,7 +189,7 @@ struct test_model_context {
         int32_t pos = seq_positions[seq_id];
         common_batch_add(batch, token, pos, { seq_id }, true);
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx.get(), batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
             llama_batch_free(batch);
             return false;
@@ -220,7 +221,7 @@ struct test_model_context {
             common_batch_add(batch, token, pos, { seq_id }, true);
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx.get(), batch) != 0) {
             fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
             llama_batch_free(batch);
             return false;
@@ -260,23 +261,13 @@ struct test_model_context {
 
     void reset() {
         if (ctx) {
-            llama_free(ctx);
-            ctx = nullptr;
+            ctx.reset();
         }
         seq_positions.clear();
         last_batch_info.clear();
     }
 
     void cleanup() {
-        if (ctx) {
-            llama_free(ctx);
-        }
-        if (model) {
-            llama_model_free(model);
-        }
-
-        ctx = nullptr;
-        model = nullptr;
         vocab = nullptr;
     }
 
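After this hunk, `cleanup()` no longer frees anything by hand: the `unique_ptr` members run their custom deleters from `~test_model_context()`, so an early assertion failure or a forgotten `cleanup()` can no longer leak the model or context. A sketch of that guarantee (the model path is hypothetical):

    {
        test_model_context test_ctx;
        test_ctx.load_model("models/test.gguf"); // hypothetical path
    } // destructor runs: ctx freed via llama_free, model via llama_model_free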
@@ -306,17 +297,17 @@ static void test_backend_greedy_sampling(const char * model_path) {
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
     printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
+    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
     printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, loop_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx);
         printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         if (!test_ctx.decode_token(token, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -344,14 +335,14 @@ static void test_backend_top_k_sampling(const char * model_path) {
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
     for (size_t i = 0; i < n_logits; ++i) {
         printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
     }
 
-    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
+    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx);
     for (size_t i = 0; i < n_candidates; ++i) {
         printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
                test_ctx.token_to_piece(candidates[i], false).c_str());
@@ -364,7 +355,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
     GGML_ASSERT(chain->iface->backend_apply != nullptr);
 
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
-    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
@@ -401,7 +392,7 @@ static void test_backend_temp_sampling(const char * model_path) {
     // Verfify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == test_ctx.n_vocab);
 
         // Sample from sequence 0 using CPU sampler
@@ -409,7 +400,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -425,7 +416,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -457,7 +448,7 @@ static void test_backend_temp_sampling(const char * model_path) {
 
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == 1);
     };
 
@@ -495,7 +486,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
     // Verify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == test_ctx.n_vocab);
     }
 }
@@ -527,7 +518,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
 
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 
         if (temp <= 0.0f && delta >= 0.0f) {
             GGML_ASSERT(n_logits == 1);
@@ -564,8 +555,8 @@ static void test_backend_min_p_sampling(const char * model_path) {
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 
     // Print the logits that are above the min-p threshold
     std::vector<float> filtered_logits;
@@ -582,7 +573,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -590,7 +581,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     // Decode and sampler 10 more tokens
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
         printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         if (!test_ctx.decode_token(token, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -620,8 +611,8 @@ static void test_backend_top_p_sampling(const char * model_path) {
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 
     // Print the logits that are above the min-p threshold
     std::vector<float> filtered_logits;
@@ -638,7 +629,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -646,7 +637,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     // Decode and sampler 10 more tokens
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
         printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         test_ctx.decode_token(token, 0);
     }
@@ -687,7 +678,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     // Verfiy sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -696,7 +687,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     // Verify sequence 1
     {
         int32_t batch_idx= test_ctx.idx_for_seq(1);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -709,7 +700,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
 
     for (llama_seq_id seq_id : {0, 1}) {
         int32_t idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
         tokens[seq_id] = token;
@@ -743,12 +734,12 @@ static void test_backend_dist_sampling(const char * model_path) {
     }
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
 
-    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
+    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
@@ -780,8 +771,8 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
-    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
+    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
     printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
     GGML_ASSERT(backend_token == cpu_token);
 
@@ -829,7 +820,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
-    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
     const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
     printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
     GGML_ASSERT(backend_token == bias_token);
@@ -872,22 +863,22 @@ static void test_backend_mixed_sampling(const char * model_path) {
     // Verfiy sequence 0 that used the dist backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
-        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
+        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
     }
 
     // Verfiy sequence 1 that used the top-k backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(logits != nullptr);
-        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(n_logits == (size_t) k);
-        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
+        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL);
     }
 
     printf("backend mixed sampling test PASSED\n");
@@ -914,12 +905,12 @@ static void test_backend_set_sampler(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
     // Sample using backend sampler configured above
-    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
     const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
 
     // Now clear the backend sampler for this sequence.
-    llama_set_sampler(test_ctx.ctx, seq_id, nullptr);
+    llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr);
     printf("Cleared backend sampler for seq_id %d\n", seq_id);
 
     // Sample using CPU sampler
@@ -934,11 +925,11 @@ static void test_backend_set_sampler(const char * model_path) {
 
     // Should not have any sampled token or probs after clearing the backend sampler.
     const int32_t idx = test_ctx.idx_for_seq(seq_id);
-    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL);
-    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);
+    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL);
+    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr);
 
     // Sample the token using the CPU sampler chain.
-    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx, seq_id);
+    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id);
     const std::string token2_str = test_ctx.token_to_piece(token2, false);
     printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
     std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };
@@ -948,13 +939,13 @@ static void test_backend_set_sampler(const char * model_path) {
     llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
     llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
     llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
-    llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain.get());
+    llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get());
 
     if (!test_ctx.decode_tokens(tokens2)) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
-    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
     const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 
@@ -990,7 +981,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     // Verify sequence 0 (backend sampled)
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -1000,14 +991,14 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
 
-        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
 
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
         llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -1017,7 +1008,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         // clear the backend sampler for seq 0 so that there are no backend
        // samplers.
-        llama_set_sampler(test_ctx.ctx, 0, nullptr);
+        llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);
 
         // Create a CPU sampler and verify we can sampler from it.
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@@ -1025,7 +1016,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
         if (!test_ctx.decode_token(token, 1)) {
             GGML_ASSERT(false && "Failed to decode token");
         }
@@ -1037,14 +1028,14 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
         llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
 
-        llama_set_sampler(test_ctx.ctx, 0, sampler_chain.get());
+        llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get());
 
         if (!test_ctx.decode_token(3834, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
         }
 
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -1087,7 +1078,7 @@ static void test_backend_max_outputs(const char * model_path) {
     }
 
     printf(">>> test_max_outputs expected error start:\n");
-    const int ret = llama_decode(test_ctx.ctx, batch);
+    const int ret = llama_decode(test_ctx.ctx.get(), batch);
     GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence");
     printf("<<< test_max_outputs expected error end.\n");
     llama_batch_free(batch);