completion : add replaying of session state

This commit updates the session handling in the completion tool to handle
the fact that logits are no longer stored in the session file. Instead, we
need to replay the last token to get the logits for sampling.
This commit is contained in:
Daniel Bevenius 2026-01-27 16:08:18 +01:00
parent 4bd1809675
commit 44bddc0a89
No known key found for this signature in database
1 changed file with 26 additions and 0 deletions

View File

@ -387,6 +387,32 @@ int main(int argc, char ** argv) {
}
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
// Logits are not stored as part of the session state so we need to
// "replay" the last token to get logits for sampling.
// This only applies when the restored session covers the ENTIRE prompt
// (n_match == session_tokens.size()): in that case no new token would be
// decoded below, so without a replay there would be no logits to sample from.
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
// The token whose logits we need to regenerate.
llama_token last_token = session_tokens.back();
int32_t pos;
if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
// NOTE(review): recurrent/hybrid state presumably cannot be rewound to a
// previous position, so the token is decoded at the NEXT position instead
// of re-decoding the last one — confirm this matches the memory semantics.
LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match);
pos = n_match; // use next position for decoding
} else {
// Attention models: drop the last position from the KV cache and
// re-decode that same token in place to regenerate its logits.
LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1);
if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) {
LOG_ERR("%s: failed to remove last position from KV cache\n", __func__);
return 1;
}
pos = n_match - 1;
}
// Single-token batch with an explicit position override. Pointing
// batch.pos at the stack variable is safe because llama_decode is
// called immediately, before pos goes out of scope.
llama_batch batch = llama_batch_get_one(&last_token, 1);
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
return 1;
}
}
}
// number of tokens to keep when resetting context