From d9a23126bf1b081ae408695a1db74d66c3b301e1 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 11 Feb 2026 13:22:50 +0100 Subject: [PATCH] common : extract replay_last_token to common.h This commit extracts the replay_last_token function from save-load-state.cpp to common.h. The motivation for this is to allow reuse of the function but also to clarify the intent of code that replays the last token after loading the session state. --- common/common.cpp | 10 ++++++++++ common/common.h | 4 ++++ examples/save-load-state/save-load-state.cpp | 17 ++++------------- tools/completion/completion.cpp | 8 +------- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 383379cb5b..55e5cb3372 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1789,6 +1789,16 @@ float lr_opt::get_lr(float epoch) const { return r; } +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) { + llama_batch batch = llama_batch_get_one(&last_token, 1); + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + LOG_ERR("%s: failed to replay last token\n", __func__); + return false; + } + return true; +} + bool common_prompt_batch_decode( struct llama_context * ctx, const std::vector & tokens, diff --git a/common/common.h b/common/common.h index c7de7d7ecb..b1124d92b6 100644 --- a/common/common.h +++ b/common/common.h @@ -794,6 +794,10 @@ bool common_prompt_batch_decode( bool save_state, bool is_last_batch = true); +// replays the last token after loading state to regenerate logits +// used after loading session state to ensure the sampling context has valid logits +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos); + // // Vocab utils // diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index c315933163..a949e1b643 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -6,17 +6,6 @@ #include #include -static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) { - llama_batch batch = llama_batch_get_one(&last_token, 1); - int pos = n_past; - batch.pos = &pos; - if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__); - return false; - } - ++n_past; - return true; -} int main(int argc, char ** argv) { common_params params; @@ -120,9 +109,10 @@ int main(int argc, char ** argv) { // restore state (last tokens) n_past = n_token_count_out; - if (!replay_last_token(ctx2, tokens.back(), n_past)) { + if (!common_replay_last_token(ctx2, tokens.back(), n_past)) { return 1; } + ++n_past; // second run for (auto i = 0; i < params.n_predict; i++) { @@ -173,9 +163,10 @@ int main(int argc, char ** argv) { // restore state (last tokens) n_past = n_token_count_out; - if (!replay_last_token(ctx3, tokens.back(), n_past)) { + if (!common_replay_last_token(ctx3, tokens.back(), n_past)) { return 1; } + ++n_past; // save seq 0 and load into seq 1 { diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 7607bc1e01..d1acbbc538 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -391,13 +391,7 @@ int main(int argc, char ** argv) { // Logits are not stored as part of the session state so we need to // "replay" the last token to get logits for sampling. if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) { - llama_token last_token = session_tokens.back(); - int32_t pos = n_match; - - llama_batch batch = llama_batch_get_one(&last_token, 1); - batch.pos = &pos; - if (llama_decode(ctx, batch)) { - LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__); + if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) { return 1; }