diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index a5a48a4d7e..811029052a 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -920,6 +920,30 @@ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) return sctx->get_name(); } +static void llama_sampler_greedy_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_greedy *) smpl->ctx; + GGML_UNUSED(ctx); +} + +static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_greedy *) smpl->ctx; + auto * result = llama_sampler_init_greedy(); + + // copy the state + { + auto * result_ctx = (llama_sampler_greedy *) result->ctx; + + GGML_UNUSED(ctx); + GGML_UNUSED(result_ctx); + } + + return result; +} + +static void llama_sampler_greedy_free(struct llama_sampler * smpl) { + delete (llama_sampler_greedy *) smpl->ctx; +} + static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { cur_p->selected = 0; for (size_t i = 1; i < cur_p->size; ++i) { @@ -959,9 +983,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = { /* .name = */ llama_sampler_greedy_name, /* .accept = */ nullptr, /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .reset = */ llama_sampler_greedy_reset, + /* .clone = */ llama_sampler_greedy_clone, + /* .free = */ llama_sampler_greedy_free, /* .backend_init = */ llama_sampler_greedy_backend_init, /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_greedy_backend_apply, @@ -1069,6 +1093,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da #endif } +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; auto * result = llama_sampler_init_dist(ctx->seed); @@ -1083,12 +1113,6 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample return result; } -static void llama_sampler_dist_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_dist *) smpl->ctx; - ctx->seed_cur = get_rng_seed(ctx->seed); - ctx->rng.seed(ctx->seed_cur); -} - static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; } diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp index 7c33d0374c..db1a2631f0 100644 --- a/tests/test-backend-sampler.cpp +++ b/tests/test-backend-sampler.cpp @@ -314,6 +314,8 @@ static void test_backend_greedy_sampling(const char * model_path) { GGML_ASSERT(false && "Failed to decode token"); } } + + llama_sampler_free(backend_sampler_chain); } static void test_backend_top_k_sampling(const char * model_path) { @@ -349,6 +351,8 @@ static void test_backend_top_k_sampling(const char * model_path) { test_ctx.token_to_piece(candidates[i], false).c_str()); } + llama_sampler_free(backend_sampler_chain); + // Sample using CPU sampler for verification that it is possible to do hybrid // sampling, first top_k on the backend and then dist on the CPU. struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); @@ -392,6 +396,9 @@ static void test_backend_temp_sampling(const char * model_path) { GGML_ASSERT(false && "Failed to decode token"); } + llama_sampler_free(backend_sampler_chain_0); + llama_sampler_free(backend_sampler_chain_1); + // Verfify sequence 0 { int32_t batch_idx = test_ctx.idx_for_seq(0); @@ -457,6 +464,8 @@ static void test_backend_temp_sampling(const char * model_path) { uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); GGML_ASSERT(n_logits == 1); + + llama_sampler_free(backend_sampler_chain); }; test_argmax_temp(0.0f); @@ -496,6 +505,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) { int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); GGML_ASSERT(n_logits == test_ctx.n_vocab); } + + llama_sampler_free(backend_sampler_chain); } test_ctx.reset(); @@ -532,6 +543,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) { } else { GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); } + + llama_sampler_free(backend_sampler_chain); }; test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0) @@ -597,6 +610,7 @@ static void test_backend_min_p_sampling(const char * model_path) { printf("min-p sampling test PASSED\n"); + llama_sampler_free(backend_sampler_chain); llama_sampler_free(chain); } @@ -653,6 +667,7 @@ static void test_backend_top_p_sampling(const char * model_path) { printf("top-p sampling test PASSED\n"); + llama_sampler_free(backend_sampler_chain); llama_sampler_free(chain); } @@ -723,6 +738,9 @@ static void test_backend_multi_sequence_sampling(const char * model_path) { } } + llama_sampler_free(sampler_chain_0); + llama_sampler_free(sampler_chain_1); + printf("backend multi-sequence sampling test PASSED\n"); } @@ -755,6 +773,7 @@ static void test_backend_dist_sampling(const char * model_path) { GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); llama_sampler_free(backend_sampler_chain); + printf("backend dist sampling test PASSED\n"); }