diff --git a/common/common.cpp b/common/common.cpp
index a9bd494191..497cfaad5e 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1243,6 +1243,9 @@ llama_context * common_init_result::context() {
 }
 
 common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    if (seq_id < 0 || seq_id >= (int) pimpl->samplers.size()) {
+        return nullptr;
+    }
     return pimpl->samplers[seq_id].get();
 }
 
diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp
index 716a30fe9a..813526a0ec 100644
--- a/tools/completion/completion.cpp
+++ b/tools/completion/completion.cpp
@@ -146,19 +146,18 @@ int main(int argc, char ** argv) {
     ctx = llama_init->context();
     model = llama_init->model();
+    smpl = llama_init->sampler(0);
 
     if (ctx == NULL) {
         LOG_ERR("%s: error: unable to create context\n", __func__);
         return 1;
     }
 
     if (model == NULL) {
         LOG_ERR("%s: error: unable to load model\n", __func__);
         return 1;
     }
 
-    smpl = llama_init->sampler(0);
-
     llama_memory_t mem = llama_get_memory(ctx);
 
     const llama_vocab * vocab = llama_model_get_vocab(model);