remove is_last_batch parameter from common_prompt_batch_decode

This commit is contained in:
Daniel Bevenius 2026-02-13 12:43:50 +01:00
parent 7902ae7380
commit d9a6e49844
No known key found for this signature in database
3 changed files with 5 additions and 6 deletions

View File

@ -1805,14 +1805,13 @@ bool common_prompt_batch_decode(
int & n_past, int & n_past,
int n_batch, int n_batch,
const std::string_view & state_path, const std::string_view & state_path,
bool save_state, bool save_state) {
bool is_last_batch) {
const int n_eval = tokens.size(); const int n_eval = tokens.size();
if (n_eval == 0) { if (n_eval == 0) {
return true; return true;
} }
if (save_state && is_last_batch && n_eval > 1) { if (save_state && n_eval > 1) {
const int n_tokens_before_last = n_eval - 1; const int n_tokens_before_last = n_eval - 1;
GGML_ASSERT(n_eval <= n_batch); GGML_ASSERT(n_eval <= n_batch);

View File

@ -790,8 +790,7 @@ bool common_prompt_batch_decode(
int & n_past, int & n_past,
int n_batch, int n_batch,
const std::string_view & state_path, const std::string_view & state_path,
bool save_state, bool save_state);
bool is_last_batch = true);
// replays the last token after loading state to regenerate logits // replays the last token after loading state to regenerate logits
// used after loading session state to ensure the sampling context has valid logits // used after loading session state to ensure the sampling context has valid logits

View File

@ -687,7 +687,8 @@ int main(int argc, char ** argv) {
if (!embd.empty()) { if (!embd.empty()) {
const bool is_last_batch = (n_consumed >= (int) embd_inp.size()); const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) { const bool save_now = session_do_save && is_last_batch;
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
return 1; return 1;
} }
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin()); session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());