From 90aa83c6bdaa6ea82e1af5a02f77ff80088653d0 Mon Sep 17 00:00:00 2001 From: mtmcp <141645996+mtmcp@users.noreply.github.com> Date: Tue, 31 Mar 2026 07:04:42 -0300 Subject: [PATCH] common: add bounds check in common_init_result::sampler to prevent segfault on failed model load (#21082) * common: add bounds check in common_init_result::sampler to prevent segfault on failed model load * Revert a308e584cae3fa8cee1d739a858a2d780f1de009 * Add regression test * Remove regression test for init-fail sampler check --- common/common.cpp | 3 +++ tools/completion/completion.cpp | 8 +------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index a9bd494191..497cfaad5e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1243,6 +1243,9 @@ llama_context * common_init_result::context() { } common_sampler * common_init_result::sampler(llama_seq_id seq_id) { + if (seq_id < 0 || seq_id >= (int) pimpl->samplers.size()) { + return nullptr; + } return pimpl->samplers[seq_id].get(); } diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 716a30fe9a..813526a0ec 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -146,19 +146,13 @@ int main(int argc, char ** argv) { ctx = llama_init->context(); model = llama_init->model(); + smpl = llama_init->sampler(0); if (ctx == NULL) { LOG_ERR("%s: error: unable to create context\n", __func__); return 1; } - if (model == NULL) { - LOG_ERR("%s: error: unable to load model\n", __func__); - return 1; - } - - smpl = llama_init->sampler(0); - llama_memory_t mem = llama_get_memory(ctx); const llama_vocab * vocab = llama_model_get_vocab(model);