tests : use smart pointers for backend samplers
commit 9a9ea2f6b1
parent c5d44b8525
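The diff below replaces the raw struct llama_sampler * handles in the backend-sampler tests with llama_sampler_ptr from llama-cpp.h, so the tests no longer need their trailing llama_sampler_free(...) calls: each chain is released automatically when its owner goes out of scope. For orientation, llama_sampler_ptr is essentially a std::unique_ptr whose deleter calls llama_sampler_free; the sketch below is only the idea, the authoritative definition lives in llama-cpp.h.

    // Sketch of the wrapper the tests switch to (assumes llama.h is included);
    // see llama-cpp.h for the real typedef.
    #include <memory>

    struct llama_sampler_deleter {
        void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); }
    };

    typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;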
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "get-model.h"
 #include "common.h"
 
@@ -290,10 +291,10 @@ static void test_backend_greedy_sampling(const char * model_path) {
 const int seq_id = 0;
 
 struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params);
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params));
 
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_greedy());
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -321,8 +322,6 @@ static void test_backend_greedy_sampling(const char * model_path) {
 GGML_ASSERT(false && "Failed to decode token");
 }
 }
-
-llama_sampler_free(backend_sampler_chain);
 }
 
 static void test_backend_top_k_sampling(const char * model_path) {
@@ -331,9 +330,9 @@ static void test_backend_top_k_sampling(const char * model_path) {
 const int seq_id = 0;
 const int32_t k = 8;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(k));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -358,22 +357,18 @@ static void test_backend_top_k_sampling(const char * model_path) {
 test_ctx.token_to_piece(candidates[i], false).c_str());
 }
 
-llama_sampler_free(backend_sampler_chain);
-
 // Sample using CPU sampler for verification that it is possible to do hybrid
 // sampling, first top_k on the backend and then dist on the CPU.
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 GGML_ASSERT(chain->iface->backend_apply != nullptr);
 
-llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 const std::string token_str = test_ctx.token_to_piece(token, false);
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
 printf("backend top-k hybrid sampling test PASSED\n");
-
-llama_sampler_free(chain);
 }
 
 static void test_backend_temp_sampling(const char * model_path) {
@@ -382,17 +377,17 @@ static void test_backend_temp_sampling(const char * model_path) {
 {
 const float temp_0 = 0.8f;
 struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
-llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
+llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0));
+llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0));
 
 const float temp_1 = 0.1f;
 struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
-llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
+llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1));
+llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ 0, backend_sampler_chain_0 },
-{ 1, backend_sampler_chain_1 }
+{ 0, backend_sampler_chain_0.get() },
+{ 1, backend_sampler_chain_1.get() }
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -403,9 +398,6 @@ static void test_backend_temp_sampling(const char * model_path) {
 GGML_ASSERT(false && "Failed to decode token");
 }
 
-llama_sampler_free(backend_sampler_chain_0);
-llama_sampler_free(backend_sampler_chain_1);
-
 // Verfify sequence 0
 {
 int32_t batch_idx = test_ctx.idx_for_seq(0);
@@ -414,15 +406,13 @@ static void test_backend_temp_sampling(const char * model_path) {
 
 // Sample from sequence 0 using CPU sampler
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 const std::string token_str = test_ctx.token_to_piece(token, false);
 printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
-llama_sampler_free(chain);
 }
 
 
@@ -432,15 +422,13 @@ static void test_backend_temp_sampling(const char * model_path) {
 
 // Sample from sequence 1 using CPU sampler
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 const std::string token_str = test_ctx.token_to_piece(token, false);
 printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
-llama_sampler_free(chain);
 }
 }
 
@@ -452,11 +440,11 @@ static void test_backend_temp_sampling(const char * model_path) {
 
 int seq_id = 0;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ seq_id, backend_sampler_chain },
+{ seq_id, backend_sampler_chain.get() },
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -471,8 +459,6 @@ static void test_backend_temp_sampling(const char * model_path) {
 
 uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
 GGML_ASSERT(n_logits == 1);
-
-llama_sampler_free(backend_sampler_chain);
 };
 
 test_argmax_temp(0.0f);
@@ -491,11 +477,11 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
 const float delta = 0.5f;
 const float exponent = 1.5f;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ seq_id, backend_sampler_chain },
+{ seq_id, backend_sampler_chain.get() },
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -512,8 +498,6 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
 int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
 GGML_ASSERT(n_logits == test_ctx.n_vocab);
 }
-
-llama_sampler_free(backend_sampler_chain);
 }
 
 test_ctx.reset();
@@ -526,11 +510,11 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
 
 int seq_id = 0;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ seq_id, backend_sampler_chain },
+{ seq_id, backend_sampler_chain.get() },
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -550,8 +534,6 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
 } else {
 GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
 }
-
-llama_sampler_free(backend_sampler_chain);
 };
 
 test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
@@ -568,9 +550,9 @@ static void test_backend_min_p_sampling(const char * model_path) {
 const int seq_id = 0;
 const float p = 0.1;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_min_p(p, 0));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -597,10 +579,10 @@ static void test_backend_min_p_sampling(const char * model_path) {
 
 // Sample using CPU sampler for verification to inspect they are reasonable
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 const std::string token_str = test_ctx.token_to_piece(token, false);
 printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -608,7 +590,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
 // Decode and sampler 10 more tokens
 for (int i = 0; i < 10; i++) {
 int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
 printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
 if (!test_ctx.decode_token(token, 0)) {
 GGML_ASSERT(false && "Failed to decode token");
@@ -616,9 +598,6 @@ static void test_backend_min_p_sampling(const char * model_path) {
 }
 
 printf("min-p sampling test PASSED\n");
-
-llama_sampler_free(backend_sampler_chain);
-llama_sampler_free(chain);
 }
 
 static void test_backend_top_p_sampling(const char * model_path) {
@@ -627,9 +606,9 @@ static void test_backend_top_p_sampling(const char * model_path) {
 const int seq_id = 0;
 const float p = 0.9;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_p(p, 0));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -656,10 +635,10 @@ static void test_backend_top_p_sampling(const char * model_path) {
 
 // Sample using CPU sampler for verification to inspect they are reasonable
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 const std::string token_str = test_ctx.token_to_piece(token, false);
 printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -667,32 +646,29 @@ static void test_backend_top_p_sampling(const char * model_path) {
 // Decode and sampler 10 more tokens
 for (int i = 0; i < 10; i++) {
 int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
 printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
 test_ctx.decode_token(token, 0);
 }
 
 printf("top-p sampling test PASSED\n");
-
-llama_sampler_free(backend_sampler_chain);
-llama_sampler_free(chain);
 }
 
 static void test_backend_multi_sequence_sampling(const char * model_path) {
 test_model_context test_ctx;
 
 struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_greedy());
+llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
+llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());
 
 struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
-struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
-llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_temp(0.8f));
-llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_greedy());
+llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
+llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f));
+llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy());
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ 0, sampler_chain_0 },
-{ 1, sampler_chain_1 }
+{ 0, sampler_chain_0.get() },
+{ 1, sampler_chain_1.get() }
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -745,9 +721,6 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
 }
 }
 
-llama_sampler_free(sampler_chain_0);
-llama_sampler_free(sampler_chain_1);
-
 printf("backend multi-sequence sampling test PASSED\n");
 }
 
@@ -757,9 +730,9 @@ static void test_backend_dist_sampling(const char * model_path) {
 const int seq_id = 189;
 const int32_t seed = 88;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -779,8 +752,6 @@ static void test_backend_dist_sampling(const char * model_path) {
 printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-llama_sampler_free(backend_sampler_chain);
-
 printf("backend dist sampling test PASSED\n");
 }
 
@@ -790,9 +761,9 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
 const int seq_id = 0;
 const int32_t seed = 88;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -806,17 +777,14 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
 
 // Sample using CPU sampler
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
 llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
-llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
 GGML_ASSERT(backend_token == cpu_token);
 
-llama_sampler_free(backend_sampler_chain);
-llama_sampler_free(chain);
-
 printf("backend dist & cpu sampling test PASSED\n");
 }
 
@@ -842,15 +810,15 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
 printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
 
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_logit_bias(
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
 llama_vocab_n_tokens(test_ctx.vocab),
 logit_bias.size(),
 logit_bias.data()));
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(88));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ seq_id, backend_sampler_chain },
+{ seq_id, backend_sampler_chain.get() },
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -867,8 +835,6 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
 GGML_ASSERT(backend_token == bias_token);
 
 printf("backend logit bias sampling test PASSED\n");
-
-llama_sampler_free(backend_sampler_chain);
 }
 
 // This test verifies that it is possible to have two different backend sampler,
@@ -877,17 +843,17 @@ static void test_backend_mixed_sampling(const char * model_path) {
 test_model_context test_ctx;
 
 struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
+llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
+llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
 
 int k = 40;
 struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
-struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
-llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_top_k(k));
+llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
+llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ 0, sampler_chain_0 },
-{ 1, sampler_chain_1 }
+{ 0, sampler_chain_0.get() },
+{ 1, sampler_chain_1.get() }
 };
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -924,9 +890,6 @@ static void test_backend_mixed_sampling(const char * model_path) {
 GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
 }
 
-llama_sampler_free(sampler_chain_0);
-llama_sampler_free(sampler_chain_1);
-
 printf("backend mixed sampling test PASSED\n");
 }
 
@@ -936,9 +899,9 @@ static void test_backend_set_sampler(const char * model_path) {
 const int32_t seed = 88;
 const int seq_id = 0;
 struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -961,8 +924,8 @@ static void test_backend_set_sampler(const char * model_path) {
 
 // Sample using CPU sampler
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
 std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, };
 if (!test_ctx.decode_tokens(tokens)) {
@@ -975,17 +938,17 @@ static void test_backend_set_sampler(const char * model_path) {
 GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);
 
 // Sample the token using the CPU sampler chain.
-llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id);
+llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx, seq_id);
 const std::string token2_str = test_ctx.token_to_piece(token2, false);
 printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
 std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };
 
 // Set a new backend sampler for the sequence.
 struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
-llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20));
-llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed));
-llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
+llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
+llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
+llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain.get());
 
 if (!test_ctx.decode_tokens(tokens2)) {
 GGML_ASSERT(false && "Failed to decode token");
@@ -995,10 +958,6 @@ static void test_backend_set_sampler(const char * model_path) {
 const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
 printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 
-llama_sampler_free(backend_sampler_chain);
-llama_sampler_free(chain);
-llama_sampler_free(new_backend_sampler_chain);
-
 printf("backend set sampler test PASSED\n");
 }
 
@@ -1007,11 +966,11 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
 
 // Sequence 0 uses backend sampling
 struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
+llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
+llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
 
 std::vector<llama_sampler_seq_config> backend_sampler_configs = {
-{ 0, sampler_chain_0 },
+{ 0, sampler_chain_0.get() },
 };
 
 // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
@@ -1045,14 +1004,13 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
 GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
 
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 const std::string token_str = test_ctx.token_to_piece(token, false);
 printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-llama_sampler_free(chain);
 }
 
 // Clear/remove the backend sampler, and sample again
@@ -1063,25 +1021,23 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
 
 // Create a CPU sampler and verify we can sampler from it.
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
 int32_t batch_idx = test_ctx.idx_for_seq(1);
-llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
 if (!test_ctx.decode_token(token, 1)) {
 GGML_ASSERT(false && "Failed to decode token");
 }
-
-llama_sampler_free(chain);
 }
 
 // Set a backend sampler so that we can verify that it can be reset
 {
 struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-struct llama_sampler * sampler_chain = llama_sampler_chain_init(chain_params);
-llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
+llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
+llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
 
-llama_set_sampler(test_ctx.ctx, 0, sampler_chain);
+llama_set_sampler(test_ctx.ctx, 0, sampler_chain.get());
 
 if (!test_ctx.decode_token(3834, 0)) {
 GGML_ASSERT(false && "Failed to decode token");
@@ -1092,12 +1048,8 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
 const std::string token_str = test_ctx.token_to_piece(token, false);
 printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
 GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
-llama_sampler_free(sampler_chain);
 }
 
-llama_sampler_free(sampler_chain_0);
-
 printf("backend-cpu mixed batch test PASSED\n");
 }
 
@@ -1107,9 +1059,9 @@ static void test_backend_max_outputs(const char * model_path) {
 const int seq_id = 0;
 const int32_t seed = 88;
 llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
 if (!test_ctx.setup(model_path, backend_sampler_configs)) {
 return;
@@ -1140,7 +1092,6 @@ static void test_backend_max_outputs(const char * model_path) {
 printf("<<< test_max_outputs expected error end.\n");
 llama_batch_free(batch);
 
-llama_sampler_free(backend_sampler_chain);
 printf("backend max outputs test PASSED\n");
 }
 
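The same pattern repeats throughout the tests above: the llama_sampler_ptr owns the chain, while llama_sampler_chain_add, llama_sampler_sample and llama_sampler_seq_config only borrow the raw pointer via .get(). A hypothetical helper, not part of this commit, that captures the shape of the new setup code:

    // Hypothetical illustration of the ownership pattern used by the updated tests
    // (assumes llama.h, llama-cpp.h and <vector> are included).
    // The returned llama_sampler_ptr must stay alive for as long as the context
    // uses the borrowed raw pointer stored in the config.
    static llama_sampler_ptr make_backend_chain(std::vector<llama_sampler_seq_config> & configs,
                                                 llama_seq_id seq_id) {
        llama_sampler_chain_params params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
        configs.push_back({ seq_id, chain.get() });
        return chain; // ownership moves to the caller; no llama_sampler_free needed
    }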