common : extract replay_last_token to common.h

This commit extracts the replay_last_token function from
save-load-state.cpp to common.cpp, with its declaration in common.h.

The motivation for this is to allow reuse of the function, and also to
clarify the intent of the code that replays the last token after loading
the session state.
This commit is contained in:
Daniel Bevenius 2026-02-11 13:22:50 +01:00
parent a70867c19e
commit d9a23126bf
No known key found for this signature in database
4 changed files with 19 additions and 20 deletions

View File

@ -1789,6 +1789,16 @@ float lr_opt::get_lr(float epoch) const {
return r;
}
// Replays a single token at the given position so the context regenerates
// logits (logits are not persisted as part of session state).
// Returns true on success, false if decoding the token failed.
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
    llama_batch batch = llama_batch_get_one(&last_token, 1);
    batch.pos = &pos; // decode at the caller-supplied position, not position 0

    const bool ok = llama_decode(ctx, batch) == 0;
    if (!ok) {
        LOG_ERR("%s: failed to replay last token\n", __func__);
    }
    return ok;
}
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & tokens,

View File

@ -794,6 +794,10 @@ bool common_prompt_batch_decode(
bool save_state,
bool is_last_batch = true);
// replays the last token after loading state to regenerate logits
// used after loading session state to ensure the sampling context has valid logits
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
//
// Vocab utils
//

View File

@ -6,17 +6,6 @@
#include <vector>
#include <cstdio>
// Decodes the given token at position n_past to regenerate logits after a
// session state load; on success advances n_past by one.
// Returns false (without touching n_past) if decoding fails.
static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) {
    int pos = n_past;

    llama_batch batch = llama_batch_get_one(&last_token, 1);
    batch.pos = &pos;

    const int ret = llama_decode(ctx, batch);
    if (ret != 0) {
        fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__);
        return false;
    }

    ++n_past;
    return true;
}
int main(int argc, char ** argv) {
common_params params;
@ -120,9 +109,10 @@ int main(int argc, char ** argv) {
// restore state (last tokens)
n_past = n_token_count_out;
if (!replay_last_token(ctx2, tokens.back(), n_past)) {
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
return 1;
}
++n_past;
// second run
for (auto i = 0; i < params.n_predict; i++) {
@ -173,9 +163,10 @@ int main(int argc, char ** argv) {
// restore state (last tokens)
n_past = n_token_count_out;
if (!replay_last_token(ctx3, tokens.back(), n_past)) {
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
return 1;
}
++n_past;
// save seq 0 and load into seq 1
{

View File

@ -391,13 +391,7 @@ int main(int argc, char ** argv) {
// Logits are not stored as part of the session state so we need to
// "replay" the last token to get logits for sampling.
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
llama_token last_token = session_tokens.back();
int32_t pos = n_match;
llama_batch batch = llama_batch_get_one(&last_token, 1);
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
return 1;
}