diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index b3f202771a..cfe1ba1703 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "get-model.h"
 #include "common.h"
@@ -290,10 +291,10 @@ static void test_backend_greedy_sampling(const char * model_path) {
     const int seq_id = 0;
 
     struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params);
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params));
 
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_greedy());
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -321,8 +322,6 @@ static void test_backend_greedy_sampling(const char * model_path) {
             GGML_ASSERT(false && "Failed to decode token");
         }
     }
-
-    llama_sampler_free(backend_sampler_chain);
 }
 
 static void test_backend_top_k_sampling(const char * model_path) {
@@ -331,9 +330,9 @@ static void test_backend_top_k_sampling(const char * model_path) {
     const int seq_id = 0;
     const int32_t k = 8;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(k));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -358,22 +357,18 @@ static void test_backend_top_k_sampling(const char * model_path) {
                test_ctx.token_to_piece(candidates[i], false).c_str());
     }
 
-    llama_sampler_free(backend_sampler_chain);
-
     // Sample using CPU sampler for verification that it is possible to do hybrid
     // sampling, first top_k on the backend and then dist on the CPU.
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
     GGML_ASSERT(chain->iface->backend_apply != nullptr);
-    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
-    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
     printf("backend top-k hybrid sampling test PASSED\n");
-
-    llama_sampler_free(chain);
 }
 
 static void test_backend_temp_sampling(const char * model_path) {
@@ -382,17 +377,17 @@ static void test_backend_temp_sampling(const char * model_path) {
     {
         const float temp_0 = 0.8f;
         struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
-        struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
-        llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
+        llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0));
+        llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0));
 
         const float temp_1 = 0.1f;
         struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
-        struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
-        llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
+        llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1));
+        llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1));
 
         std::vector backend_sampler_configs = {
-            { 0, backend_sampler_chain_0 },
-            { 1, backend_sampler_chain_1 }
+            { 0, backend_sampler_chain_0.get() },
+            { 1, backend_sampler_chain_1.get() }
         };
 
         if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -403,9 +398,6 @@ static void test_backend_temp_sampling(const char * model_path) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
-    llama_sampler_free(backend_sampler_chain_0);
-    llama_sampler_free(backend_sampler_chain_1);
-
     // Verfify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
@@ -414,15 +406,13 @@ static void test_backend_temp_sampling(const char * model_path) {
 
         // Sample from sequence 0 using CPU sampler
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+        llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
-        llama_sampler_free(chain);
     }
 
@@ -432,15 +422,13 @@ static void test_backend_temp_sampling(const char * model_path) {
 
         // Sample from sequence 1 using CPU sampler
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+        llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
-        llama_sampler_free(chain);
     }
 }
@@ -452,11 +440,11 @@ static void test_backend_temp_sampling(const char * model_path) {
         int seq_id = 0;
 
         struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));
+        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp));
 
         std::vector backend_sampler_configs = {
-            { seq_id, backend_sampler_chain },
+            { seq_id, backend_sampler_chain.get() },
         };
 
         if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -471,8 +459,6 @@ static void test_backend_temp_sampling(const char * model_path) {
 
         uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(n_logits == 1);
-
-        llama_sampler_free(backend_sampler_chain);
     };
 
     test_argmax_temp(0.0f);
@@ -491,11 +477,11 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         const float delta = 0.5f;
         const float exponent = 1.5f;
         struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));
 
         std::vector backend_sampler_configs = {
-            { seq_id, backend_sampler_chain },
+            { seq_id, backend_sampler_chain.get() },
         };
 
         if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -512,8 +498,6 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
             int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
             GGML_ASSERT(n_logits == test_ctx.n_vocab);
         }
-
-        llama_sampler_free(backend_sampler_chain);
     }
 
     test_ctx.reset();
@@ -526,11 +510,11 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         int seq_id = 0;
 
         struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));
 
         std::vector backend_sampler_configs = {
-            { seq_id, backend_sampler_chain },
+            { seq_id, backend_sampler_chain.get() },
         };
 
         if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -550,8 +534,6 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         } else {
             GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
         }
-
-        llama_sampler_free(backend_sampler_chain);
    };
 
     test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
@@ -568,9 +550,9 @@ static void test_backend_min_p_sampling(const char * model_path) {
     const int seq_id = 0;
     const float p = 0.1;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_min_p(p, 0));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -597,10 +579,10 @@ static void test_backend_min_p_sampling(const char * model_path) {
 
     // Sample using CPU sampler for verification to inspect they are reasonable
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-    llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
+    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -608,7 +590,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     // Decode and sampler 10 more tokens
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
         printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         if (!test_ctx.decode_token(token, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -616,9 +598,6 @@ static void test_backend_min_p_sampling(const char * model_path) {
     }
 
     printf("min-p sampling test PASSED\n");
-
-    llama_sampler_free(backend_sampler_chain);
-    llama_sampler_free(chain);
 }
 
 static void test_backend_top_p_sampling(const char * model_path) {
@@ -627,9 +606,9 @@ static void test_backend_top_p_sampling(const char * model_path) {
     const int seq_id = 0;
     const float p = 0.9;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_p(p, 0));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -656,10 +635,10 @@ static void test_backend_top_p_sampling(const char * model_path) {
 
     // Sample using CPU sampler for verification to inspect they are reasonable
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-    llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
+    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 
-    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -667,32 +646,29 @@ static void test_backend_top_p_sampling(const char * model_path) {
     // Decode and sampler 10 more tokens
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, loop_idx);
         printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         test_ctx.decode_token(token, 0);
     }
 
     printf("top-p sampling test PASSED\n");
-
-    llama_sampler_free(backend_sampler_chain);
-    llama_sampler_free(chain);
 }
 
 static void test_backend_multi_sequence_sampling(const char * model_path) {
     test_model_context test_ctx;
 
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_greedy());
+    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
+    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());
 
     struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
-    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
-    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_temp(0.8f));
-    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_greedy());
+    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
+    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy());
 
     std::vector backend_sampler_configs = {
-        { 0, sampler_chain_0 },
-        { 1, sampler_chain_1 }
+        { 0, sampler_chain_0.get() },
+        { 1, sampler_chain_1.get() }
     };
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -745,9 +721,6 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
         }
     }
 
-    llama_sampler_free(sampler_chain_0);
-    llama_sampler_free(sampler_chain_1);
-
     printf("backend multi-sequence sampling test PASSED\n");
 }
 
@@ -757,9 +730,9 @@ static void test_backend_dist_sampling(const char * model_path) {
     const int seq_id = 189;
     const int32_t seed = 88;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -779,8 +752,6 @@ static void test_backend_dist_sampling(const char * model_path) {
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-    llama_sampler_free(backend_sampler_chain);
-
     printf("backend dist sampling test PASSED\n");
 }
 
@@ -790,9 +761,9 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     const int seq_id = 0;
     const int32_t seed = 88;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -806,17 +777,14 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
 
     // Sample using CPU sampler
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
     llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
-    llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
     printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
     GGML_ASSERT(backend_token == cpu_token);
 
-    llama_sampler_free(backend_sampler_chain);
-    llama_sampler_free(chain);
-
     printf("backend dist & cpu sampling test PASSED\n");
 }
 
@@ -842,15 +810,15 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
 
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_logit_bias(
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
                 llama_vocab_n_tokens(test_ctx.vocab),
                 logit_bias.size(),
                 logit_bias.data()));
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(88));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
 
     std::vector backend_sampler_configs = {
-        { seq_id, backend_sampler_chain },
+        { seq_id, backend_sampler_chain.get() },
     };
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -867,8 +835,6 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     GGML_ASSERT(backend_token == bias_token);
 
     printf("backend logit bias sampling test PASSED\n");
-
-    llama_sampler_free(backend_sampler_chain);
 }
 
 // This test verifies that it is possible to have two different backend sampler,
@@ -877,17 +843,17 @@ static void test_backend_mixed_sampling(const char * model_path) {
     test_model_context test_ctx;
 
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
+    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
+    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
 
     int k = 40;
     struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
-    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
-    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_top_k(k));
+    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
+    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k));
 
     std::vector backend_sampler_configs = {
-        { 0, sampler_chain_0 },
-        { 1, sampler_chain_1 }
+        { 0, sampler_chain_0.get() },
+        { 1, sampler_chain_1.get() }
     };
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -924,9 +890,6 @@ static void test_backend_mixed_sampling(const char * model_path) {
         GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
     }
 
-    llama_sampler_free(sampler_chain_0);
-    llama_sampler_free(sampler_chain_1);
-
     printf("backend mixed sampling test PASSED\n");
 }
 
@@ -936,9 +899,9 @@ static void test_backend_set_sampler(const char * model_path) {
     const int32_t seed = 88;
     const int seq_id = 0;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -961,8 +924,8 @@ static void test_backend_set_sampler(const char * model_path) {
 
     // Sample using CPU sampler
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 
     std::map tokens = { { seq_id, backend_token}, };
     if (!test_ctx.decode_tokens(tokens)) {
@@ -975,17 +938,17 @@ static void test_backend_set_sampler(const char * model_path) {
     GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);
 
     // Sample the token using the CPU sampler chain.
-    llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id);
+    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx, seq_id);
     const std::string token2_str = test_ctx.token_to_piece(token2, false);
     printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
     std::map tokens2 = { { seq_id, token2}, };
 
     // Set a new backend sampler for the sequence.
     struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
-    struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
-    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20));
-    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed));
-    llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
+    llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
+    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
+    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+    llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain.get());
 
     if (!test_ctx.decode_tokens(tokens2)) {
         GGML_ASSERT(false && "Failed to decode token");
@@ -995,10 +958,6 @@ static void test_backend_set_sampler(const char * model_path) {
     const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 
-    llama_sampler_free(backend_sampler_chain);
-    llama_sampler_free(chain);
-    llama_sampler_free(new_backend_sampler_chain);
-
     printf("backend set sampler test PASSED\n");
 }
 
@@ -1007,11 +966,11 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
 
     // Sequence 0 uses backend sampling
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
-    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
+    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
+    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
 
     std::vector backend_sampler_configs = {
-        { 0, sampler_chain_0 },
+        { 0, sampler_chain_0.get() },
     };
 
     // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
@@ -1045,14 +1004,13 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
 
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-        llama_sampler_free(chain);
     }
 
     // Clear/remove the backend sampler, and sample again
@@ -1063,25 +1021,23 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         // Create a CPU sampler and verify we can sampler from it.
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
+        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx, batch_idx);
         if (!test_ctx.decode_token(token, 1)) {
             GGML_ASSERT(false && "Failed to decode token");
        }
-
-        llama_sampler_free(chain);
     }
 
     // Set a backend sampler so that we can verify that it can be reset
     {
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * sampler_chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
+        llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
+        llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
 
-        llama_set_sampler(test_ctx.ctx, 0, sampler_chain);
+        llama_set_sampler(test_ctx.ctx, 0, sampler_chain.get());
 
         if (!test_ctx.decode_token(3834, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -1092,12 +1048,8 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
-        llama_sampler_free(sampler_chain);
     }
 
-    llama_sampler_free(sampler_chain_0);
-
     printf("backend-cpu mixed batch test PASSED\n");
 }
 
@@ -1107,9 +1059,9 @@ static void test_backend_max_outputs(const char * model_path) {
     const int seq_id = 0;
     const int32_t seed = 88;
     llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
-    llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
-    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
+    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
         return;
@@ -1140,7 +1092,6 @@ static void test_backend_max_outputs(const char * model_path) {
     printf("<<< test_max_outputs expected error end.\n");
 
     llama_batch_free(batch);
-    llama_sampler_free(backend_sampler_chain);
 
     printf("backend max outputs test PASSED\n");
 }
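
For reference, a minimal sketch (not part of the patch) of the ownership pattern the diff switches to, assuming `llama_sampler_ptr` is the `std::unique_ptr` alias declared in `llama-cpp.h` whose deleter invokes `llama_sampler_free`; the function name below is illustrative only:

```cpp
#include "llama.h"
#include "llama-cpp.h"   // defines llama_sampler_ptr (unique_ptr + llama_sampler_free deleter)

static void sampler_chain_raii_sketch() {
    struct llama_sampler_chain_params params = llama_sampler_chain_default_params();

    // The smart pointer takes ownership of the newly created chain.
    llama_sampler_ptr chain(llama_sampler_chain_init(params));

    // The chain owns the samplers added to it; .get() only lends the raw
    // pointer for the duration of each call.
    llama_sampler_chain_add(chain.get(), llama_sampler_init_top_k(8));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    // No llama_sampler_free() here: the chain (and the samplers it owns) is
    // released automatically when `chain` goes out of scope, including on
    // early returns such as the setup-failure paths in these tests.
}
```

This is why each test can drop its trailing `llama_sampler_free` calls: cleanup now happens on every exit path rather than only when the test reaches its final lines.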