diff --git a/common/common.cpp b/common/common.cpp index 7a89f16250..9792c0b6a6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1178,6 +1178,12 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) { return pimpl->samplers[seq_id].get(); } +void common_init_result::reset_samplers() { + for (int i = 0; i < (int) pimpl->samplers.size(); ++i) { + llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get())); + } +} + std::vector & common_init_result::lora() { return pimpl->lora; } @@ -1311,6 +1317,8 @@ common_init_result_ptr common_init_from_params(common_params & params) { llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); + // reset samplers to reset RNG state after warmup to the seeded state + res->reset_samplers(); } return res; diff --git a/common/common.h b/common/common.h index 431bc6f3dc..5eeee7d64a 100644 --- a/common/common.h +++ b/common/common.h @@ -690,7 +690,9 @@ struct common_init_result { llama_model * model(); llama_context * context(); + common_sampler * sampler(llama_seq_id seq_id); + void reset_samplers(); std::vector & lora();