completion : add replying of session state
This commit updates the session handing in the completion tool to handle the that logits are no longer stored in the session file. Instead, we need to replay the last token to get the logits for sampling.
This commit is contained in:
parent
4bd1809675
commit
44bddc0a89
|
|
@ -387,6 +387,32 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
|
||||
|
||||
// Logits are not stored as part of the session state so we need to
|
||||
// "replay" the last token to get logits for sampling.
|
||||
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
|
||||
llama_token last_token = session_tokens.back();
|
||||
int32_t pos;
|
||||
|
||||
if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
|
||||
LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match);
|
||||
pos = n_match; // use next position for decoding
|
||||
} else {
|
||||
LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1);
|
||||
if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) {
|
||||
LOG_ERR("%s: failed to remove last position from KV cache\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
pos = n_match - 1;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
|
|
|
|||
Loading…
Reference in New Issue