diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp
index 977132756f..86590ee263 100644
--- a/tools/completion/completion.cpp
+++ b/tools/completion/completion.cpp
@@ -387,6 +387,32 @@ int main(int argc, char ** argv) {
         }
 
         session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
+
+        // Logits are not stored as part of the session state, so we need to
+        // "replay" the last token to get logits for sampling.
+        if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
+            llama_token last_token = session_tokens.back();
+            int32_t pos;
+
+            if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
+                LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match);
+                pos = n_match; // use next position for decoding
+            } else {
+                LOG_INF("%s: non-recurrent model: removing and re-decoding last position: %d\n", __func__, (int)n_match - 1);
+                if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) {
+                    LOG_ERR("%s: failed to remove last position from KV cache\n", __func__);
+                    return 1;
+                }
+                pos = n_match - 1;
+            }
+
+            llama_batch batch = llama_batch_get_one(&last_token, 1);
+            batch.pos = &pos;
+            if (llama_decode(ctx, batch)) {
+                LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
+                return 1;
+            }
+        }
     }
 
     // number of tokens to keep when resetting context
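
Note on the hunk above: for recurrent/hybrid models the cached last token is decoded again at the next position (n_match) rather than re-decoded in place, apparently because the recurrent state cannot be rolled back to an earlier position the way a single KV-cache slot can be removed for attention-based models. The sketch below is not part of the patch; it is a minimal illustration of what the replay decode makes available, namely logits for the restored prompt's last position, using a plain greedy argmax in place of the tool's actual sampling chain (only public llama.cpp API calls are assumed).

    // Illustrative sketch (not in the patch): after the replay decode succeeds,
    // the logits for the last restored token can be read and sampled from.
    const float * logits  = llama_get_logits_ith(ctx, -1);
    const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    // greedy argmax over the vocabulary, standing in for the real sampler chain
    llama_token best = 0;
    for (int32_t i = 1; i < n_vocab; ++i) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    LOG_INF("%s: first token proposed after session restore: %d\n", __func__, best);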