diff --git a/common/common.cpp b/common/common.cpp
index 3aa396127c..bd94d951c5 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1863,3 +1863,70 @@ float lr_opt::get_lr(float epoch) const {
     LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
     return r;
 }
+
+// Decodes one batch of prompt tokens and advances n_past by the number of tokens decoded.
+//
+// If save_state and is_last_batch are set and the batch has more than one token, the batch is
+// decoded in two steps so the state can be written to state_path before the last token:
+// all tokens but the last, then the save, then the last token on its own.
+//
+// Returns false when llama_decode reports a failure. A failed save is logged but is not fatal.
+bool common_prompt_batch_decode(
+        struct llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+        int & n_past,
+        int n_batch,
+        const std::filesystem::path & state_path,
+        bool save_state,
+        bool is_last_batch) {
+    const int n_eval = static_cast<int>(tokens.size());
+    if (n_eval == 0) {
+        return true;
+    }
+
+    // A batch must never exceed the configured batch size, whichever path is taken below.
+    GGML_ASSERT(n_eval <= n_batch);
+
+    if (save_state && is_last_batch && n_eval > 1) {
+        const int n_tokens_before_last = n_eval - 1;
+
+        // Decode all but the last token so we can save the memory state before decoding the last token.
+        // This is done so we can restore the session state later and replay the last token.
+        // Memory implementations in recurrent/hybrid models don't support removing tokens from their
+        // memory, so we can't just remove the last token from the memory and replay the last token which
+        // is the reason for this logic.
+        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token *>(tokens.data()), n_tokens_before_last))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_tokens_before_last;
+
+        // Only report success when the file was actually written; a failed save does not stop decoding.
+        const std::string session_file = state_path.string();
+        if (llama_state_save_file(ctx, session_file.c_str(), tokens.data(), n_tokens_before_last)) {
+            LOG_INF("saved session before last token to %s, n_tokens = %d\n", session_file.c_str(), n_tokens_before_last);
+        } else {
+            LOG_ERR("%s : failed to save session to %s\n", __func__, session_file.c_str());
+        }
+
+        // Replay the last token at an explicit position so it lands right after the saved prefix.
+        llama_token last_token = tokens.back();
+        llama_batch batch = llama_batch_get_one(&last_token, 1);
+        int32_t pos = n_past;
+        batch.pos = &pos;
+
+        if (llama_decode(ctx, batch)) {
+            LOG_ERR("%s : failed to eval last token\n", __func__);
+            return false;
+        }
+        n_past++;
+    } else {
+        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token *>(tokens.data()), n_eval))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_eval;
+    }
+
+    return true;
+}
diff --git a/common/common.h b/common/common.h
index 398ebb0960..dad97b5a96 100644
--- a/common/common.h
+++ b/common/common.h
@@ -5,6 +5,7 @@
 #include "ggml-opt.h"
 #include "llama-cpp.h"
 
+#include <filesystem>
 #include <set>
 #include <sstream>
 #include <string>
@@ -780,6 +781,20 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
                                bool   logits);
 
+// decodes a single batch of tokens for a prompt and manages session tokens
+//
+// Note: We save state before the last token so that we can replay it to ensure
+// compatibility with all memory types. Recurrent/hybrid models cannot remove
+// tokens from memory, so this approach works across all model architectures.
+bool common_prompt_batch_decode(
+        struct llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+        int & n_past,
+        int n_batch,
+        const std::filesystem::path & state_path,
+        bool save_state,
+        bool is_last_batch = true);
+
 //
 // Token utils
 //
diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp
index 86590ee263..7607bc1e01 100644
--- a/tools/completion/completion.cpp
+++ b/tools/completion/completion.cpp
@@ -392,19 +392,7 @@ int main(int argc, char ** argv) {
         // "replay" the last token to get logits for sampling.
         if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
             llama_token last_token = session_tokens.back();
-            int32_t pos;
-
-            if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
-                LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match);
-                pos = n_match; // use next position for decoding
-            } else {
-                LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1);
-                if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) {
-                    LOG_ERR("%s: failed to remove last position from KV cache\n", __func__);
-                    return 1;
-                }
-                pos = n_match - 1;
-            }
+            int32_t pos = n_match;
 
             llama_batch batch = llama_batch_get_one(&last_token, 1);
             batch.pos = &pos;
@@ -412,6 +400,9 @@ int main(int argc, char ** argv) {
                 LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
                 return 1;
             }
+
+            session_do_save = false;
+            LOG_INF("%s: replayed last token from session\n", __func__);
         }
     }
 
@@ -701,40 +692,28 @@ int main(int argc, char ** argv) {
             }
 
             if (!embd.empty()) {
-                int n_eval = (int) embd.size();
-                LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
-
-                GGML_ASSERT(n_eval <= params.n_batch);
-                if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
-                    LOG_ERR("%s : failed to eval\n", __func__);
+                const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
+                if (!common_prompt_batch_decode(ctx, embd, n_past,
+                            params.n_batch, path_session, session_do_save, is_last_batch)) {
                     return 1;
                 }
-
-                n_past += n_eval;
+                // record every token that was just decoded so session_tokens mirrors the context
+                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
+                n_session_consumed = session_tokens.size();
+                session_do_save = false;
 
                 LOG_DBG("n_past = %d\n", n_past);
+
                 // Display total tokens alongside total time
                 if (params.n_print > 0 && n_past % params.n_print == 0) {
                     LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
                 }
             }
-
-            if (!embd.empty() && !path_session.empty()) {
-                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
-                n_session_consumed = session_tokens.size();
-            }
         }
 
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            // optionally save the session on first sample (for faster prompt loading next time)
-            if (session_do_save) {
-                session_do_save = false;
-                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
-
-                LOG_DBG("saved session to %s\n", path_session.c_str());
-            }
 
             const llama_token id = common_sampler_sample(smpl, ctx, -1);
 