completion : fix prompt cache for recurrent models (#19045)

commit 080b161995
parent 1243f93a2d
Author: Georgi Gerganov
Date:   2026-01-25 09:12:50 +02:00 (committed by GitHub)
2 changed files with 48 additions and 41 deletions
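
In short, reading the diff: the main example used to trim the restored session with llama_memory_seq_rm and, when that call failed, left session_tokens holding stale "future" tokens. Recurrent models cannot drop a partial suffix of their state, so for them that failure path was the norm. The change scopes the prefix matching into a block, clears both the memory (llama_memory_clear) and the cached token list when partial removal is impossible, and replaces the need_to_save_session flag with session_do_save, computed only after the trimming has settled.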


@@ -2559,6 +2559,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
+    // [TAG_CONTEXT_STATE_LOGITS]
     // write logits
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
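
The tag added above marks the spot where logits are serialized into the session state; the comment added in main.cpp below refers back to it, since reusing a fully matching session means reusing exactly these logits. For context, a minimal sketch of the save/load pair involved (llama_state_save_file and llama_state_load_file are the public API from llama.h; the file path and helper names are illustrative only):

    #include "llama.h"

    #include <vector>

    // Illustrative helpers: save the evaluated tokens together with the full
    // context state (memory, logits, ...), then restore both in a later run.
    static void save_session(llama_context * ctx, const std::vector<llama_token> & tokens) {
        llama_state_save_file(ctx, "prompt.cache", tokens.data(), tokens.size());
    }

    static std::vector<llama_token> load_session(llama_context * ctx, size_t n_ctx) {
        std::vector<llama_token> tokens(n_ctx);
        size_t n_loaded = 0;
        if (!llama_state_load_file(ctx, "prompt.cache", tokens.data(), tokens.size(), &n_loaded)) {
            n_loaded = 0; // no usable session file - start cold
        }
        tokens.resize(n_loaded);
        return tokens;
    }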


@@ -342,44 +342,51 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
     // debug message about similarity of saved session, if applicable
-    size_t n_matching_session_tokens = 0;
-    if (!session_tokens.empty()) {
-        for (llama_token id : session_tokens) {
-            if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
-                break;
-            }
-            n_matching_session_tokens++;
-        }
-
-        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
-            LOG_INF("%s: using full prompt from session file\n", __func__);
-        } else if (n_matching_session_tokens >= embd_inp.size()) {
-            LOG_INF("%s: session file has exact match for prompt!\n", __func__);
-        } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
-            LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
-                __func__, n_matching_session_tokens, embd_inp.size());
-        } else {
-            LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
-                __func__, n_matching_session_tokens, embd_inp.size());
-        }
-
-        // remove any "future" tokens that we might have inherited from the previous session
-        if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) {
-            LOG_INF("%s: unable to resuse common prefix\n", __func__);
-            n_matching_session_tokens = 0;
-            llama_memory_seq_rm(mem, -1, -1, -1);
-        }
-    }
-
-    LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
-        embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size());
-
-    // if we will use the cache for the full prompt without reaching the end of the cache, force
-    // reevaluation of the last token to recalculate the cached logits
-    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
-        LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
-        session_tokens.resize(embd_inp.size() - 1);
-    }
+    bool session_do_save = false;
+    {
+        size_t n_match = 0;
+        if (!session_tokens.empty()) {
+            for (llama_token id : session_tokens) {
+                if (n_match >= embd_inp.size() || id != embd_inp[n_match]) {
+                    break;
+                }
+                n_match++;
+            }
+
+            if (params.prompt.empty() && n_match == embd_inp.size()) {
+                LOG_INF("%s: using full prompt from session file\n", __func__);
+            } else if (n_match >= embd_inp.size()) {
+                LOG_INF("%s: session file has exact match for prompt!\n", __func__);
+            } else if (n_match < (embd_inp.size() / 2)) {
+                LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
+                    __func__, n_match, embd_inp.size());
+            } else {
+                LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
+                    __func__, n_match, embd_inp.size());
+            }
+
+            if (session_tokens.size() == n_match) {
+                // [TAG_CONTEXT_STATE_LOGITS]
+                // in this case, we are going to reuse the logits from the session
+                // if we ever decide to remove the logits from the session, we need to handle this somehow
+                // ref: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941
+            }
+
+            // remove any "future" tokens that we might have inherited from the previous session
+            if (session_tokens.size() > n_match) {
+                if (!llama_memory_seq_rm(mem, -1, n_match, -1)) {
+                    LOG_WRN("%s: unable to resuse common prefix (for example, when the memory is recurrent)\n", __func__);
+                    llama_memory_clear(mem, true);
+                    session_tokens.clear();
+                    n_match = 0;
+                } else {
+                    session_tokens.resize(n_match);
+                }
+            }
+        }
+
+        session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
+    }
 
     // number of tokens to keep when resetting context
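
The heart of the fix is the new fallback above: a transformer KV cache can drop just the suffix [n_match, end), but a recurrent model keeps a single fixed-size state, so llama_memory_seq_rm returns false and the code now resets both the memory and session_tokens instead of leaving them out of sync. The matching loop itself is a plain common-prefix computation; a hypothetical standalone equivalent:

    #include "llama.h"

    #include <algorithm>
    #include <vector>

    // Hypothetical helper mirroring the loop above: number of leading tokens
    // shared by the restored session and the freshly tokenized prompt.
    static size_t common_prefix(const std::vector<llama_token> & session,
                                const std::vector<llama_token> & prompt) {
        auto mis = std::mismatch(session.begin(), session.end(),
                                 prompt.begin(), prompt.end());
        return (size_t) std::distance(session.begin(), mis.first);
    }

Note also that session_do_save is now computed after the trimming settles, so a read-only cache (params.prompt_cache_ro) or a full match never triggers a rewrite of the session file.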
@@ -521,10 +528,9 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }
 
-    bool is_antiprompt        = false;
-    bool input_echo           = true;
-    bool display              = true;
-    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
+    bool is_antiprompt = false;
+    bool input_echo    = true;
+    bool display       = true;
 
     int n_past   = 0;
     int n_remain = params.n_predict;
@@ -700,8 +706,8 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // optionally save the session on first sample (for faster prompt loading next time)
-            if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
-                need_to_save_session = false;
+            if (session_do_save) {
+                session_do_save = false;
 
                 llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
                 LOG_DBG("saved session to %s\n", path_session.c_str());