From d9a6e49844a0eddee7c2c52729c5dc1a78076134 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 13 Feb 2026 12:43:50 +0100 Subject: [PATCH] remove is_last_batch parameter from common_prompt_batch_decode --- common/common.cpp | 5 ++--- common/common.h | 3 +-- tools/completion/completion.cpp | 3 ++- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b1382e0644..615ce81c62 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1805,14 +1805,13 @@ bool common_prompt_batch_decode( int & n_past, int n_batch, const std::string_view & state_path, - bool save_state, - bool is_last_batch) { + bool save_state) { const int n_eval = tokens.size(); if (n_eval == 0) { return true; } - if (save_state && is_last_batch && n_eval > 1) { + if (save_state && n_eval > 1) { const int n_tokens_before_last = n_eval - 1; GGML_ASSERT(n_eval <= n_batch); diff --git a/common/common.h b/common/common.h index 5b3e5bdc5b..d7ced8a2ca 100644 --- a/common/common.h +++ b/common/common.h @@ -790,8 +790,7 @@ bool common_prompt_batch_decode( int & n_past, int n_batch, const std::string_view & state_path, - bool save_state, - bool is_last_batch = true); + bool save_state); // replays the last token after loading state to regenerate logits // used after loading session state to ensure the sampling context has valid logits diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index d1acbbc538..aed2c0e38f 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -687,7 +687,8 @@ int main(int argc, char ** argv) { if (!embd.empty()) { const bool is_last_batch = (n_consumed >= (int) embd_inp.size()); - if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) { + const bool save_now = session_do_save && is_last_batch; + if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) { return 1; } 
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); // append every token just decoded; a [begin, begin) range is empty and would insert nothing