diff --git a/common/common.cpp b/common/common.cpp
index 3aa396127c..bd94d951c5 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1863,3 +1863,59 @@ float lr_opt::get_lr(float epoch) const {
     LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
     return r;
 }
+
+bool common_prompt_batch_decode(
+        struct llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+        int & n_past,
+        int n_batch,
+        const std::filesystem::path & state_path,
+        bool save_state,
+        bool is_last_batch) {
+    const int n_eval = tokens.size();
+    if (n_eval == 0) {
+        return true;
+    }
+
+    if (save_state && is_last_batch && n_eval > 1) {
+        const int n_tokens_before_last = n_eval - 1;
+
+        GGML_ASSERT(n_eval <= n_batch);
+
+        // Decode all but the last token so we can save the memory state before decoding the last token.
+        // This is done so we can restore the session state later and replay the last token.
+        // Memory implementations in recurrent/hybrid models don't support removing tokens from their
+        // memory, so we can't just remove the last token from the memory and replay the last token which
+        // is the reason for this logic.
+        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token *>(tokens.data()), n_tokens_before_last))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_tokens_before_last;
+
+        if (llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last)) {
+            LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last);
+        } else {
+            LOG_ERR("%s : failed to save session to %s\n", __func__, state_path.string().c_str());
+        }
+
+        llama_token last_token = tokens.back();
+        llama_batch batch = llama_batch_get_one(&last_token, 1);
+        int32_t pos = n_past;
+        batch.pos = &pos;
+
+        if (llama_decode(ctx, batch)) {
+            LOG_ERR("%s : failed to eval last token\n", __func__);
+            return false;
+        }
+        n_past++;
+    } else {
+        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token *>(tokens.data()), n_eval))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_eval;
+    }
+
+    return true;
+}
diff --git a/common/common.h b/common/common.h
index 398ebb0960..dad97b5a96 100644
--- a/common/common.h
+++ b/common/common.h
@@ -5,6 +5,7 @@
 #include "ggml-opt.h"
 #include "llama-cpp.h"
 
+#include <filesystem>
 #include <set>
 #include <sstream>
 #include <string>
@@ -780,6 +781,25 @@ void common_batch_add(
         const std::vector<llama_seq_id> & seq_ids,
         bool logits);
 
+// decodes a single batch of prompt tokens and optionally saves the session state
+//
+// When save_state and is_last_batch are both set and the batch holds more than one
+// token, all but the last token are decoded, the state is written to state_path,
+// and then the last token is decoded. Otherwise the whole batch is decoded at once.
+// n_past is advanced by the number of tokens decoded; returns false if decoding fails.
+//
+// Note: We save state before the last token so that we can replay it to ensure
+// compatibility with all memory types. Recurrent/hybrid models cannot remove
+// tokens from memory, so this approach works across all model architectures.
+bool common_prompt_batch_decode(
+        struct llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+        int & n_past,
+        int n_batch,
+        const std::filesystem::path & state_path,
+        bool save_state,
+        bool is_last_batch = true);
+
 //
 // Token utils
 //
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 39d4464663..c315933163 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -2,15 +2,30 @@
 #include "common.h"
 #include "llama.h"
 
+#include <filesystem>
 #include <vector>
 #include <cstdio>
 
+static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) {
+    llama_batch batch = llama_batch_get_one(&last_token, 1);
+    int pos = n_past;
+    batch.pos = &pos;
+    if (llama_decode(ctx, batch)) {
+        fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__);
+        return false;
+    }
+    ++n_past;
+    return true;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
     params.prompt = "The quick brown fox";
     params.sampling.seed = 1234;
 
+    std::filesystem::path state_file = "dump_state.bin";
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
         return 1;
     }
@@ -53,35 +68,16 @@ int main(int argc, char ** argv) {
     // tokenize prompt
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
-    // prepare the batch
-    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
-    for (size_t i = 0; i < tokens.size(); i++) {
-        common_batch_add(batch, tokens[i], i, {0}, false);
+    const bool save_state = true;
+    if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
+        return 1;
     }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
-
-    // evaluate prompt
-    llama_decode(ctx, batch);
-    n_past += batch.n_tokens;
-
-    // save state (rng, logits, embedding and kv_cache) to file
-    {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
-
-        FILE *fp_write = fopen("dump_state.bin", "wb");
-        fwrite(state_mem.data(), 1, written, fp_write);
-        fclose(fp_write);
-
-        fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
-    }
-
-    // save state (last tokens)
-    const auto n_past_saved = n_past;
 
     // first run
     printf("\nfirst run: %s", params.prompt.c_str());
 
+    llama_batch batch = llama_batch_init(1, 0, 1);
+
     for (auto i = 0; i < params.n_predict; i++) {
         auto next_token = llama_sampler_sample(smpl, ctx, -1);
         auto next_token_str = common_token_to_piece(ctx, next_token);
@@ -111,27 +107,22 @@ int main(int argc, char ** argv) {
 
     printf("\nsecond run: %s", params.prompt.c_str());
 
-    // load state (rng, logits, embedding and kv_cache) from file
-    {
-        std::vector<uint8_t> state_mem;
+    // load state from file
+    std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
+    size_t n_token_count_out = 0;
 
-        FILE * fp_read = fopen("dump_state.bin", "rb");
-        fseek(fp_read, 0, SEEK_END);
-        state_mem.resize(ftell(fp_read));
-        fseek(fp_read, 0, SEEK_SET);
-        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
-        fclose(fp_read);
-
-        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
-            fprintf(stderr, "\n%s : failed to read state\n", __func__);
-            return 1;
-        }
-
-        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+    if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
+        fprintf(stderr, "\n%s : failed to load state\n", __func__);
+        return 1;
     }
 
+    fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
+
     // restore state (last tokens)
-    n_past = n_past_saved;
+    n_past = n_token_count_out;
+    if (!replay_last_token(ctx2, tokens.back(), n_past)) {
+        return 1;
+    }
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
@@ -160,7 +151,9 @@ int main(int argc, char ** argv) {
     }
 
     // make new context
-    llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
+    auto params_ctx3 = common_context_params_to_llama(params);
+    params_ctx3.n_seq_max = 2;
+    llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
 
     llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 
@@ -169,26 +162,20 @@ int main(int argc, char ** argv) {
     printf("\nsingle seq run: %s", params.prompt.c_str());
 
     // load state (rng, logits, embedding and kv_cache) from file
-    {
-        std::vector<uint8_t> state_mem;
+    n_token_count_out = 0;
 
-        FILE * fp_read = fopen("dump_state.bin", "rb");
-        fseek(fp_read, 0, SEEK_END);
-        state_mem.resize(ftell(fp_read));
-        fseek(fp_read, 0, SEEK_SET);
-        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
-        fclose(fp_read);
-
-        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
-            fprintf(stderr, "\n%s : failed to read state\n", __func__);
-            return 1;
-        }
-
-        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+    if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
+        fprintf(stderr, "\n%s : failed to load state\n", __func__);
+        return 1;
     }
 
+    fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
+
     // restore state (last tokens)
-    n_past = n_past_saved;
+    n_past = n_token_count_out;
+    if (!replay_last_token(ctx3, tokens.back(), n_past)) {
+        return 1;
+    }
 
     // save seq 0 and load into seq 1
     {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a6df893a31..a106dcf51c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2500,64 +2500,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
-    // write output ids
-    {
-        LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
-
-        const auto n_outputs    = this->n_outputs;
-        const auto & output_ids = this->output_ids;
-
-        std::vector<int32_t> w_output_pos;
-
-        w_output_pos.resize(n_outputs);
-
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch(); ++i) {
-            // map an output id to a position in the batch
-            int64_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT(pos < n_outputs);
-                w_output_pos[pos] = i;
-            }
-        }
-
-        io.write(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs) {
-            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
-        }
-    }
-
-    // [TAG_CONTEXT_STATE_LOGITS]
-    // write logits
-    {
-        LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
-
-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
-
-        io.write(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            io.write(logits, logits_size * sizeof(float));
-        }
-    }
-
-    // write embeddings
-    {
-        LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
-
-        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
-
-        io.write(&embd_size, sizeof(embd_size));
-
-        if (embd_size) {
-            io.write(embd, embd_size * sizeof(float));
-        }
-    }
-
-    // TODO: handle sampling buffers and samplers state ?
-    //       https://github.com/ggml-org/llama.cpp/pull/17004
-
     if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
         memory->state_write(io);
@@ -2583,70 +2525,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         // TODO: add more info which needs to be identical but which is not verified otherwise
     }
 
-    // read output ids
-    {
-        LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
-
-        auto n_outputs = this->n_outputs;
-        io.read_to(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs > output_reserve(n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
-
-        std::vector<int32_t> output_pos;
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= n_batch()) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
-                }
-                this->output_ids[id] = i;
-            }
-
-            this->n_outputs = n_outputs;
-        }
-    }
-
-    // read logits
-    {
-        LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
-
-        uint64_t logits_size;
-        io.read_to(&logits_size, sizeof(logits_size));
-
-        if (this->logits_size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
-
-        if (logits_size) {
-            io.read_to(this->logits, logits_size * sizeof(float));
-        }
-    }
-
-    // read embeddings
-    {
-        LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
-
-        uint64_t embd_size;
-        io.read_to(&embd_size, sizeof(embd_size));
-
-        if (this->embd_size < embd_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
-
-        if (embd_size) {
-            io.read_to(this->embd, embd_size * sizeof(float));
-        }
-    }
-
-    // TODO: handle sampling buffers and samplers state ?
-    //       https://github.com/ggml-org/llama.cpp/pull/17004
-
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
 
diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp
index 977132756f..7607bc1e01 100644
--- a/tools/completion/completion.cpp
+++ b/tools/completion/completion.cpp
@@ -387,6 +387,23 @@ int main(int argc, char ** argv) {
         }
 
         session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
+
+        // Logits are not stored as part of the session state so we need to
+        // "replay" the last token to get logits for sampling.
+        if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
+            llama_token last_token = session_tokens.back();
+            int32_t pos = n_match;
+
+            llama_batch batch = llama_batch_get_one(&last_token, 1);
+            batch.pos = &pos;
+            if (llama_decode(ctx, batch)) {
+                LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
+                return 1;
+            }
+
+            session_do_save = false;
+            LOG_INF("%s: replayed last token from session\n", __func__);
+        }
     }
 
     // number of tokens to keep when resetting context
@@ -675,40 +692,26 @@ int main(int argc, char ** argv) {
             }
 
             if (!embd.empty()) {
-                int n_eval = (int) embd.size();
-                LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
-
-                GGML_ASSERT(n_eval <= params.n_batch);
-                if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
-                    LOG_ERR("%s : failed to eval\n", __func__);
+                const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
+                if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) {
                     return 1;
                 }
-
-                n_past += n_eval;
+                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
+                n_session_consumed = session_tokens.size();
+                session_do_save = false;
 
                 LOG_DBG("n_past = %d\n", n_past);
+
                 // Display total tokens alongside total time
                 if (params.n_print > 0 && n_past % params.n_print == 0) {
                     LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
                 }
             }
-
-            if (!embd.empty() && !path_session.empty()) {
-                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
-                n_session_consumed = session_tokens.size();
-            }
         }
 
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            // optionally save the session on first sample (for faster prompt loading next time)
-            if (session_do_save) {
-                session_do_save = false;
-                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
-
-                LOG_DBG("saved session to %s\n", path_session.c_str());
-            }
 
             const llama_token id = common_sampler_sample(smpl, ctx, -1);
 