completion : fix prompt cache for recurrent models (#19045)
This commit is contained in:
parent
1243f93a2d
commit
080b161995
|
|
@ -2559,6 +2559,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|||
}
|
||||
}
|
||||
|
||||
// [TAG_CONTEXT_STATE_LOGITS]
|
||||
// write logits
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
||||
|
|
|
|||
|
|
@ -342,44 +342,51 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
// debug message about similarity of saved session, if applicable
|
||||
size_t n_matching_session_tokens = 0;
|
||||
if (!session_tokens.empty()) {
|
||||
for (llama_token id : session_tokens) {
|
||||
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
|
||||
break;
|
||||
bool session_do_save = false;
|
||||
|
||||
{
|
||||
size_t n_match = 0;
|
||||
|
||||
if (!session_tokens.empty()) {
|
||||
for (llama_token id : session_tokens) {
|
||||
if (n_match >= embd_inp.size() || id != embd_inp[n_match]) {
|
||||
break;
|
||||
}
|
||||
n_match++;
|
||||
}
|
||||
if (params.prompt.empty() && n_match == embd_inp.size()) {
|
||||
LOG_INF("%s: using full prompt from session file\n", __func__);
|
||||
} else if (n_match >= embd_inp.size()) {
|
||||
LOG_INF("%s: session file has exact match for prompt!\n", __func__);
|
||||
} else if (n_match < (embd_inp.size() / 2)) {
|
||||
LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
||||
__func__, n_match, embd_inp.size());
|
||||
} else {
|
||||
LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
|
||||
__func__, n_match, embd_inp.size());
|
||||
}
|
||||
|
||||
if (session_tokens.size() == n_match) {
|
||||
// [TAG_CONTEXT_STATE_LOGITS]
|
||||
// in this case, we are going to reuse the logits from the session
|
||||
// if we ever decide to remove the logits from the session, we need to handle this somehow
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
if (session_tokens.size() > n_match) {
|
||||
if (!llama_memory_seq_rm(mem, -1, n_match, -1)) {
|
||||
LOG_WRN("%s: unable to resuse common prefix (for example, when the memory is recurrent)\n", __func__);
|
||||
llama_memory_clear(mem, true);
|
||||
session_tokens.clear();
|
||||
n_match = 0;
|
||||
} else {
|
||||
session_tokens.resize(n_match);
|
||||
}
|
||||
}
|
||||
n_matching_session_tokens++;
|
||||
}
|
||||
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
|
||||
LOG_INF("%s: using full prompt from session file\n", __func__);
|
||||
} else if (n_matching_session_tokens >= embd_inp.size()) {
|
||||
LOG_INF("%s: session file has exact match for prompt!\n", __func__);
|
||||
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
|
||||
LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
||||
__func__, n_matching_session_tokens, embd_inp.size());
|
||||
} else {
|
||||
LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
|
||||
__func__, n_matching_session_tokens, embd_inp.size());
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) {
|
||||
LOG_INF("%s: unable to resuse common prefix\n", __func__);
|
||||
n_matching_session_tokens = 0;
|
||||
llama_memory_seq_rm(mem, -1, -1, -1);
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
|
||||
embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size());
|
||||
|
||||
// if we will use the cache for the full prompt without reaching the end of the cache, force
|
||||
// reevaluation of the last token to recalculate the cached logits
|
||||
if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
|
||||
LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
|
||||
|
||||
session_tokens.resize(embd_inp.size() - 1);
|
||||
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
|
||||
}
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
|
|
@ -521,10 +528,9 @@ int main(int argc, char ** argv) {
|
|||
is_interacting = params.interactive_first;
|
||||
}
|
||||
|
||||
bool is_antiprompt = false;
|
||||
bool input_echo = true;
|
||||
bool display = true;
|
||||
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
|
||||
bool is_antiprompt = false;
|
||||
bool input_echo = true;
|
||||
bool display = true;
|
||||
|
||||
int n_past = 0;
|
||||
int n_remain = params.n_predict;
|
||||
|
|
@ -700,8 +706,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
||||
need_to_save_session = false;
|
||||
if (session_do_save) {
|
||||
session_do_save = false;
|
||||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
|
||||
LOG_DBG("saved session to %s\n", path_session.c_str());
|
||||
|
|
|
|||
Loading…
Reference in New Issue