diff --git a/common/common.cpp b/common/common.cpp
index 383379cb5b..55e5cb3372 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1789,6 +1789,16 @@ float lr_opt::get_lr(float epoch) const {
     return r;
 }
 
+bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
+    llama_batch batch = llama_batch_get_one(&last_token, 1);
+    batch.pos = &pos;
+    if (llama_decode(ctx, batch)) {
+        LOG_ERR("%s: failed to replay last token\n", __func__);
+        return false;
+    }
+    return true;
+}
+
 bool common_prompt_batch_decode(
     struct llama_context * ctx,
     const std::vector<llama_token> & tokens,
diff --git a/common/common.h b/common/common.h
index c7de7d7ecb..b1124d92b6 100644
--- a/common/common.h
+++ b/common/common.h
@@ -794,6 +794,10 @@ bool common_prompt_batch_decode(
     bool save_state,
     bool is_last_batch = true);
 
+// replays the last token after loading state to regenerate logits
+// used after loading session state to ensure the sampling context has valid logits
+bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
+
 //
 // Vocab utils
 //
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index c315933163..a949e1b643 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -6,17 +6,6 @@
 #include <vector>
 #include <cstdio>
 
-static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) {
-    llama_batch batch = llama_batch_get_one(&last_token, 1);
-    int pos = n_past;
-    batch.pos = &pos;
-    if (llama_decode(ctx, batch)) {
-        fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__);
-        return false;
-    }
-    ++n_past;
-    return true;
-}
 
 int main(int argc, char ** argv) {
     common_params params;
@@ -120,9 +109,10 @@ int main(int argc, char ** argv) {
 
     // restore state (last tokens)
     n_past = n_token_count_out;
-    if (!replay_last_token(ctx2, tokens.back(), n_past)) {
+    if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
         return 1;
     }
+    ++n_past;
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
@@ -173,9 +163,10 @@ int main(int argc, char ** argv) {
 
     // restore state (last tokens)
     n_past = n_token_count_out;
-    if (!replay_last_token(ctx3, tokens.back(), n_past)) {
+    if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
         return 1;
     }
+    ++n_past;
 
     // save seq 0 and load into seq 1
     {
diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp
index 7607bc1e01..d1acbbc538 100644
--- a/tools/completion/completion.cpp
+++ b/tools/completion/completion.cpp
@@ -391,13 +391,7 @@ int main(int argc, char ** argv) {
     // Logits are not stored as part of the session state so we need to
     // "replay" the last token to get logits for sampling.
     if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
-        llama_token last_token = session_tokens.back();
-        int32_t pos = n_match;
-
-        llama_batch batch = llama_batch_get_one(&last_token, 1);
-        batch.pos = &pos;
-        if (llama_decode(ctx, batch)) {
-            LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
+        if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
             return 1;
         }
 