diff --git a/common/sampling.cpp b/common/sampling.cpp index 8a931d51fc..54f6377a1e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -334,15 +334,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } void common_sampler_free(struct common_sampler * gsmpl) { - if (gsmpl) { - llama_sampler_free(gsmpl->grmr); - llama_sampler_free(gsmpl->chain); - - delete gsmpl; + if (!gsmpl) { + return; } + + llama_sampler_free(gsmpl->grmr); + llama_sampler_free(gsmpl->chain); + + delete gsmpl; } void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { + if (!gsmpl) { + return; + } + const auto tm = gsmpl->tm(); if (gsmpl->grmr && accept_grammar) { @@ -355,6 +361,10 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo } void common_sampler_reset(struct common_sampler * gsmpl) { + if (!gsmpl) { + return; + } + gsmpl->reset(); } @@ -415,6 +425,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { + if (!gsmpl) { + return nullptr; + } + return gsmpl->chain; }