From 2f148f6cb3fa00003f02d71d07128c539649b9c0 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 27 Jan 2026 16:01:43 +0100 Subject: [PATCH 1/9] llama : remove write/read of output ids/logits/embeddings This commit removes the write/read of output ids, logits and embeddings from the llama context state. Refs: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941 --- src/llama-context.cpp | 122 ------------------------------------------ 1 file changed, 122 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fc05989aa5..e6575da876 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2482,64 +2482,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { // TODO: add more model-specific info which should prevent loading the session file if not identical } - // write output ids - { - LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - - const auto n_outputs = this->n_outputs; - const auto & output_ids = this->output_ids; - - std::vector w_output_pos; - - w_output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch(); ++i) { - // map an output id to a position in the batch - int64_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT(pos < n_outputs); - w_output_pos[pos] = i; - } - } - - io.write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - // [TAG_CONTEXT_STATE_LOGITS] - // write logits - { - LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - - const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens()); - - io.write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - io.write(logits.data, logits_size * sizeof(float)); - } - } - - // write embeddings - { - LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - - const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd); - - io.write(&embd_size, sizeof(embd_size)); - - if (embd_size) { - io.write(embd.data, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? 
- // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2565,70 +2507,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { // TODO: add more info which needs to be identical but which is not verified otherwise } - // read output ids - { - LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); - - auto n_outputs = this->n_outputs; - io.read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > output_reserve(n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - std::vector<int32_t> output_pos; - - if (n_outputs) { - output_pos.resize(n_outputs); - io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); - } - this->output_ids[id] = i; - } - - this->n_outputs = n_outputs; - } - } - - // read logits - { - LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); - - uint64_t logits_size; - io.read_to(&logits_size, sizeof(logits_size)); - - if (this->logits.size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - io.read_to(this->logits.data, logits_size * sizeof(float)); - } - } - - // read embeddings - { - LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); - - uint64_t embd_size; - io.read_to(&embd_size, sizeof(embd_size)); - - if (this->embd.size < embd_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embd_size) { - io.read_to(this->embd.data, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); From 79f7d4351df70ba5a2e6ecee4089ff374203ac13 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 27 Jan 2026 16:08:18 +0100 Subject: [PATCH 2/9] completion : add replaying of session state This commit updates the session handling in the completion tool to handle the fact that logits are no longer stored in the session file. Instead, we need to replay the last token to get the logits for sampling. --- tools/completion/completion.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 977132756f..86590ee263 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -387,6 +387,32 @@ int main(int argc, char ** argv) { } session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro; + + // Logits are not stored as part of the session state so we need to + // "replay" the last token to get logits for sampling.
+ if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) { + llama_token last_token = session_tokens.back(); + int32_t pos; + + if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) { + LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match); + pos = n_match; // use next position for decoding + } else { + LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1); + if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) { + LOG_ERR("%s: failed to remove last position from KV cache\n", __func__); + return 1; + } + pos = n_match - 1; + } + + llama_batch batch = llama_batch_get_one(&last_token, 1); + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__); + return 1; + } + } } // number of tokens to keep when resetting context From aebc54600bbc5d91e329941afc4dcfa38ce563f2 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 3 Feb 2026 10:13:04 +0100 Subject: [PATCH 3/9] common : add common_prompt_batch_decode function This commit adds a new function which is responsible for decoding a prompt and optionally handling the saving of session data. --- common/common.cpp | 53 +++++++++++++++++++++++++++++++++ common/common.h | 15 ++++++++++ tools/completion/completion.cpp | 43 +++++++------------------- 3 files changed, 78 insertions(+), 33 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 32487ddc61..383379cb5b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1788,3 +1788,56 @@ float lr_opt::get_lr(float epoch) const { LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); return r; } + +bool common_prompt_batch_decode( + struct llama_context * ctx, + const std::vector<llama_token> & tokens, + int & n_past, + int n_batch, + const std::filesystem::path & state_path, + bool save_state, + bool is_last_batch) { + const int n_eval = tokens.size(); + if (n_eval == 0) { + return true; + } + + if (save_state && is_last_batch && n_eval > 1) { + const int n_tokens_before_last = n_eval - 1; + + GGML_ASSERT(n_eval <= n_batch); + + // Decode all but the last token so we can save the memory state before decoding the last token. + // This is done so we can restore the session state later and replay the last token. + // Memory implementations in recurrent/hybrid models don't support removing tokens from their + // memory, so we can't just remove the last token from the memory and replay the last token which + // is the reason for this logic.
+ if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_tokens_before_last))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + n_past += n_tokens_before_last; + + llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last); + LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last); + + llama_token last_token = tokens.back(); + llama_batch batch = llama_batch_get_one(&last_token, 1); + int32_t pos = n_past; + batch.pos = &pos; + + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval last token\n", __func__); + return false; + } + n_past++; + } else { + if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_eval))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + + return true; +} diff --git a/common/common.h b/common/common.h index 804485fb19..c7de7d7ecb 100644 --- a/common/common.h +++ b/common/common.h @@ -5,6 +5,7 @@ #include "ggml-opt.h" #include "llama-cpp.h" +#include #include #include #include @@ -779,6 +780,20 @@ void common_batch_add( const std::vector & seq_ids, bool logits); +// decodes a single batch of tokens for a prompt and manages session tokens +// +// Note: We save state before the last token so that we can replay it to ensure +// compatibility with all memory types. Recurrent/hybrid models cannot remove +// tokens from memory, so this approach works across all model architectures. +bool common_prompt_batch_decode( + struct llama_context * ctx, + const std::vector & embd, + int & n_past, + int n_batch, + const std::filesystem::path & state_path, + bool save_state, + bool is_last_batch = true); + // // Vocab utils // diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 86590ee263..7607bc1e01 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -392,19 +392,7 @@ int main(int argc, char ** argv) { // "replay" the last token to get logits for sampling. 
+ if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) { llama_token last_token = session_tokens.back(); - int32_t pos; - - if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) { - LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match); - pos = n_match; // use next position for decoding - } else { - LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1); - if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) { - LOG_ERR("%s: failed to remove last position from KV cache\n", __func__); - return 1; - } - pos = n_match - 1; - } + int32_t pos = n_match; llama_batch batch = llama_batch_get_one(&last_token, 1); batch.pos = &pos; @@ -412,6 +400,9 @@ int main(int argc, char ** argv) { LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__); return 1; } + + session_do_save = false; + LOG_INF("%s: replayed last token from session\n", __func__); } } // number of tokens to keep when resetting context @@ -701,40 +692,26 @@ int main(int argc, char ** argv) { } if (!embd.empty()) { - int n_eval = (int) embd.size(); - LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - - GGML_ASSERT(n_eval <= params.n_batch); - if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) { - LOG_ERR("%s : failed to eval\n", __func__); + const bool is_last_batch = (n_consumed >= (int) embd_inp.size()); + if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) { return 1; } - - n_past += n_eval; + session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); + n_session_consumed = session_tokens.size(); + session_do_save = false; LOG_DBG("n_past = %d\n", n_past); + // Display total tokens alongside total time if (params.n_print > 0 && n_past % params.n_print == 0) { LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx); } } - - if (!embd.empty() && !path_session.empty()) { - session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); - n_session_consumed = session_tokens.size(); - } } embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - // optionally save the session on first sample (for faster prompt loading next time) - if (session_do_save) { - session_do_save = false; - llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); - - LOG_DBG("saved session to %s\n", path_session.c_str()); - } const llama_token id = common_sampler_sample(smpl, ctx, -1); From 295c6a063928cd2f7f6fe656d93a1d3061d976fe Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 5 Feb 2026 09:40:07 +0100 Subject: [PATCH 4/9] update save-load-state.cpp to use llama_state_load_file This commit updates the save-load-state example to use the llama_state_load_file function for loading the model state from a file. It also replays the last token after loading, since the state is now stored before the last token is processed. I'm not sure if this is acceptable or not, as it does change the example to not directly use llama_state_get_data and llama_state_set_data for loading, which might have been the point of the example.
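For context, the load-then-replay flow that the example now relies on looks roughly like the sketch below. This is a minimal illustration only, not code from the patch: the helper name load_state_and_replay is made up, ctx is assumed to be an already initialized llama_context, and tokens is assumed to hold the prompt tokens that were decoded when the state file was written.

```cpp
#include "llama.h"

#include <cstdio>
#include <vector>

// Sketch: restore a state file that was written just before the last prompt
// token was decoded, then replay that token so the context has fresh logits.
static bool load_state_and_replay(llama_context * ctx,
                                  const char * state_file,
                                  const std::vector<llama_token> & tokens) {
    if (tokens.empty()) {
        return false;
    }

    std::vector<llama_token> unused(tokens.size());
    size_t n_loaded = 0;

    // restores the memory module (KV cache / recurrent state) from the file
    if (!llama_state_load_file(ctx, state_file, unused.data(), unused.size(), &n_loaded)) {
        fprintf(stderr, "failed to load state from %s\n", state_file);
        return false;
    }

    // logits are no longer serialized, so decode the last token again at the
    // position right after the saved prefix
    llama_token last = tokens.back();
    int32_t     pos  = (int32_t) n_loaded;

    llama_batch batch = llama_batch_get_one(&last, 1);
    batch.pos = &pos;

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "failed to replay last token\n");
        return false;
    }
    return true;
}
```

The replayed decode is what repopulates the logits that are no longer serialized with the state.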
--- examples/save-load-state/save-load-state.cpp | 99 +++++++++----------- 1 file changed, 42 insertions(+), 57 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 39d4464663..8b111a2ca8 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -2,15 +2,30 @@ #include "common.h" #include "llama.h" +#include #include #include +static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) { + llama_batch batch = llama_batch_get_one(&last_token, 1); + int pos = n_past; + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__); + return false; + } + ++n_past; + return true; +} + int main(int argc, char ** argv) { common_params params; params.prompt = "The quick brown fox"; params.sampling.seed = 1234; + std::filesystem::path state_file = "dump_state.bin"; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } @@ -53,35 +68,16 @@ int main(int argc, char ** argv) { // tokenize prompt auto tokens = common_tokenize(ctx, params.prompt, true); - // prepare the batch - llama_batch batch = llama_batch_init(tokens.size(), 0, 1); - for (size_t i = 0; i < tokens.size(); i++) { - common_batch_add(batch, tokens[i], i, {0}, false); + const bool save_state = true; + if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) { + return 1; } - batch.logits[batch.n_tokens - 1] = true; // generate next token - - // evaluate prompt - llama_decode(ctx, batch); - n_past += batch.n_tokens; - - // save state (rng, logits, embedding and kv_cache) to file - { - std::vector state_mem(llama_state_get_size(ctx)); - const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size()); - - FILE *fp_write = fopen("dump_state.bin", "wb"); - fwrite(state_mem.data(), 1, written, fp_write); - fclose(fp_write); - - fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size()); - } - - // save state (last tokens) - const auto n_past_saved = n_past; // first run printf("\nfirst run: %s", params.prompt.c_str()); + llama_batch batch = llama_batch_init(1, 0, 1); + for (auto i = 0; i < params.n_predict; i++) { auto next_token = llama_sampler_sample(smpl, ctx, -1); auto next_token_str = common_token_to_piece(ctx, next_token); @@ -111,27 +107,22 @@ int main(int argc, char ** argv) { printf("\nsecond run: %s", params.prompt.c_str()); - // load state (rng, logits, embedding and kv_cache) from file - { - std::vector state_mem; + // load state from file + std::vector unused_sts(tokens.size()); // unused session tokens. 
+ size_t n_token_count_out = 0; - FILE * fp_read = fopen("dump_state.bin", "rb"); - fseek(fp_read, 0, SEEK_END); - state_mem.resize(ftell(fp_read)); - fseek(fp_read, 0, SEEK_SET); - const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); - fclose(fp_read); - - if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) { - fprintf(stderr, "\n%s : failed to read state\n", __func__); - return 1; - } - - fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + fprintf(stderr, "\n%s : failed to load state\n", __func__); + return 1; } + fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out); + // restore state (last tokens) - n_past = n_past_saved; + n_past = n_token_count_out; + if (!replay_last_token(ctx2, tokens.back(), n_past)) { + return 1; + } // second run for (auto i = 0; i < params.n_predict; i++) { @@ -169,26 +160,20 @@ int main(int argc, char ** argv) { printf("\nsingle seq run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file - { - std::vector<uint8_t> state_mem; + n_token_count_out = 0; - FILE * fp_read = fopen("dump_state.bin", "rb"); - fseek(fp_read, 0, SEEK_END); - state_mem.resize(ftell(fp_read)); - fseek(fp_read, 0, SEEK_SET); - const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); - fclose(fp_read); - - if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) { - fprintf(stderr, "\n%s : failed to read state\n", __func__); - return 1; - } - - fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + fprintf(stderr, "\n%s : failed to load state\n", __func__); + return 1; } + fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out); + // restore state (last tokens) - n_past = n_past_saved; + n_past = n_token_count_out; + if (!replay_last_token(ctx3, tokens.back(), n_past)) { + return 1; + } // save seq 0 and load into seq 1 { From a70867c19e286f445f62fbaf3a14d432f11112e4 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 6 Feb 2026 13:15:20 +0100 Subject: [PATCH 5/9] examples : set n_seq_max = 2 for ctx3 This commit updates the save-load-state example to set the n_seq_max parameter to 2 when initializing the ctx3 context. The motivation for this change is that with n_parallel/n_seq_max set to 1 the context only supports one sequence, but the test later tries to use a second sequence, which results in the following error: ```console main : loaded state with 4 tokens main : seq 0 copied, 225760 bytes main : kv cache cleared find_slot: seq_id=1 >= n_seq_max=1 Try using a bigger --parallel value state_read_meta: failed to find available cells in kv cache ``` This seems to only happen for recurrent/hybrid models.
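To make the failure mode concrete, the sketch below shows the shape of the "save seq 0 and load into seq 1" step together with the context setup that this patch changes. It is an illustration only, under the assumption that a llama_model is already loaded and that the prompt would be decoded into sequence 0 before the copy; the helper name copy_seq0_to_seq1 is made up.

```cpp
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <vector>

// Sketch: the "save seq 0 and load into seq 1" step only works if the context
// was created with room for two sequences; seq_id 1 is rejected otherwise.
static bool copy_seq0_to_seq1(llama_model * model, const common_params & params) {
    llama_context_params cparams = common_context_params_to_llama(params);
    cparams.n_seq_max = 2; // with the default of 1 only seq_id 0 is valid

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        return false;
    }

    // ... decode the prompt into sequence 0 here, as the example does ...

    // serialize sequence 0 and restore it as sequence 1
    std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 0));
    const size_t n_copied = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 0);
    const size_t n_set    = llama_state_seq_set_data(ctx, seq_state.data(), n_copied, 1);

    fprintf(stderr, "seq 0 copied into seq 1, %zu bytes\n", n_set);

    llama_free(ctx);
    return n_set != 0;
}
```

With n_seq_max left at 1, the llama_state_seq_set_data call targeting sequence 1 appears to be what triggers the find_slot error shown above.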
--- examples/save-load-state/save-load-state.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 8b111a2ca8..c315933163 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -151,7 +151,9 @@ int main(int argc, char ** argv) { } // make new context - llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); + auto params_ctx3 = common_context_params_to_llama(params); + params_ctx3.n_seq_max = 2; + llama_context * ctx3 = llama_init_from_model(model, params_ctx3); llama_sampler * smpl3 = llama_sampler_chain_init(sparams); From d9a23126bf1b081ae408695a1db74d66c3b301e1 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 11 Feb 2026 13:22:50 +0100 Subject: [PATCH 6/9] common : extract replay_last_token to common.h This commit extracts the replay_last_token function from save-load-state.cpp to common.h. The motivation for this is to allow reuse of the function but also to clarify the intent of code that replays the last token after loading the session state. --- common/common.cpp | 10 ++++++++++ common/common.h | 4 ++++ examples/save-load-state/save-load-state.cpp | 17 ++++------------- tools/completion/completion.cpp | 8 +------- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 383379cb5b..55e5cb3372 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1789,6 +1789,16 @@ float lr_opt::get_lr(float epoch) const { return r; } +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) { + llama_batch batch = llama_batch_get_one(&last_token, 1); + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + LOG_ERR("%s: failed to replay last token\n", __func__); + return false; + } + return true; +} + bool common_prompt_batch_decode( struct llama_context * ctx, const std::vector & tokens, diff --git a/common/common.h b/common/common.h index c7de7d7ecb..b1124d92b6 100644 --- a/common/common.h +++ b/common/common.h @@ -794,6 +794,10 @@ bool common_prompt_batch_decode( bool save_state, bool is_last_batch = true); +// replays the last token after loading state to regenerate logits +// used after loading session state to ensure the sampling context has valid logits +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos); + // // Vocab utils // diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index c315933163..a949e1b643 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -6,17 +6,6 @@ #include #include -static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) { - llama_batch batch = llama_batch_get_one(&last_token, 1); - int pos = n_past; - batch.pos = &pos; - if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__); - return false; - } - ++n_past; - return true; -} int main(int argc, char ** argv) { common_params params; @@ -120,9 +109,10 @@ int main(int argc, char ** argv) { // restore state (last tokens) n_past = n_token_count_out; - if (!replay_last_token(ctx2, tokens.back(), n_past)) { + if (!common_replay_last_token(ctx2, tokens.back(), n_past)) { return 1; } + ++n_past; // second run for (auto i = 0; i < params.n_predict; i++) { @@ -173,9 +163,10 
@@ int main(int argc, char ** argv) { // restore state (last tokens) n_past = n_token_count_out; - if (!replay_last_token(ctx3, tokens.back(), n_past)) { + if (!common_replay_last_token(ctx3, tokens.back(), n_past)) { return 1; } + ++n_past; // save seq 0 and load into seq 1 { diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 7607bc1e01..d1acbbc538 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -391,13 +391,7 @@ int main(int argc, char ** argv) { // Logits are not stored as part of the session state so we need to // "replay" the last token to get logits for sampling. if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) { - llama_token last_token = session_tokens.back(); - int32_t pos = n_match; - - llama_batch batch = llama_batch_get_one(&last_token, 1); - batch.pos = &pos; - if (llama_decode(ctx, batch)) { - LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__); + if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) { return 1; } From 7902ae7380ef1845d575d655ec05adcddc0a8473 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 13 Feb 2026 12:11:53 +0100 Subject: [PATCH 7/9] common : use string_view instead of std::filesystem::path --- common/common.cpp | 6 +++--- common/common.h | 3 +-- examples/save-load-state/save-load-state.cpp | 7 +++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 55e5cb3372..b1382e0644 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1804,7 +1804,7 @@ bool common_prompt_batch_decode( const std::vector<llama_token> & tokens, int & n_past, int n_batch, - const std::filesystem::path & state_path, + const std::string_view & state_path, bool save_state, bool is_last_batch) { const int n_eval = tokens.size(); @@ -1828,8 +1828,8 @@ bool common_prompt_batch_decode( } n_past += n_tokens_before_last; - llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last); - LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last); + llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last); + LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last); llama_token last_token = tokens.back(); llama_batch batch = llama_batch_get_one(&last_token, 1); diff --git a/common/common.h b/common/common.h index b1124d92b6..5b3e5bdc5b 100644 --- a/common/common.h +++ b/common/common.h @@ -5,7 +5,6 @@ #include "ggml-opt.h" #include "llama-cpp.h" -#include <filesystem> #include #include #include @@ -790,7 +789,7 @@ bool common_prompt_batch_decode( const std::vector<llama_token> & embd, int & n_past, int n_batch, - const std::filesystem::path & state_path, + const std::string_view & state_path, bool save_state, bool is_last_batch = true); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index a949e1b643..5e35dcd603 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -2,7 +2,6 @@ #include "common.h" #include "llama.h" -#include <filesystem> #include #include @@ -13,7 +12,7 @@ int main(int argc, char ** argv) { params.prompt = "The quick brown fox"; params.sampling.seed = 1234; - std::filesystem::path state_file = "dump_state.bin"; + const std::string_view state_file = "dump_state.bin"; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } @@
-100,7 +99,7 @@ int main(int argc, char ** argv) { std::vector unused_sts(tokens.size()); // unused session tokens. size_t n_token_count_out = 0; - if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { fprintf(stderr, "\n%s : failed to load state\n", __func__); return 1; } @@ -154,7 +153,7 @@ int main(int argc, char ** argv) { // load state (rng, logits, embedding and kv_cache) from file n_token_count_out = 0; - if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { fprintf(stderr, "\n%s : failed to load state\n", __func__); return 1; } From d9a6e49844a0eddee7c2c52729c5dc1a78076134 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 13 Feb 2026 12:43:50 +0100 Subject: [PATCH 8/9] remove is_last_batch parameter from common_prompt_batch_decode --- common/common.cpp | 5 ++--- common/common.h | 3 +-- tools/completion/completion.cpp | 3 ++- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b1382e0644..615ce81c62 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1805,14 +1805,13 @@ bool common_prompt_batch_decode( int & n_past, int n_batch, const std::string_view & state_path, - bool save_state, - bool is_last_batch) { + bool save_state) { const int n_eval = tokens.size(); if (n_eval == 0) { return true; } - if (save_state && is_last_batch && n_eval > 1) { + if (save_state && n_eval > 1) { const int n_tokens_before_last = n_eval - 1; GGML_ASSERT(n_eval <= n_batch); diff --git a/common/common.h b/common/common.h index 5b3e5bdc5b..d7ced8a2ca 100644 --- a/common/common.h +++ b/common/common.h @@ -790,8 +790,7 @@ bool common_prompt_batch_decode( int & n_past, int n_batch, const std::string_view & state_path, - bool save_state, - bool is_last_batch = true); + bool save_state); // replays the last token after loading state to regenerate logits // used after loading session state to ensure the sampling context has valid logits diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index d1acbbc538..aed2c0e38f 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -687,7 +687,8 @@ int main(int argc, char ** argv) { if (!embd.empty()) { const bool is_last_batch = (n_consumed >= (int) embd_inp.size()); - if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) { + const bool save_now = session_do_save && is_last_batch; + if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) { return 1; } session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin()); From 43f301fa730fa70dccc016750ab2a8ed5fe503e0 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 13 Feb 2026 15:59:54 +0100 Subject: [PATCH 9/9] pass std::string_view by value --- common/common.cpp | 2 +- common/common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 615ce81c62..e6286cb071 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1804,7 +1804,7 @@ bool common_prompt_batch_decode( const std::vector & tokens, int & n_past, int n_batch, - const std::string_view & state_path, + 
std::string_view state_path, bool save_state) { const int n_eval = tokens.size(); if (n_eval == 0) { diff --git a/common/common.h b/common/common.h index d7ced8a2ca..50a6b27490 100644 --- a/common/common.h +++ b/common/common.h @@ -789,7 +789,7 @@ bool common_prompt_batch_decode( const std::vector & embd, int & n_past, int n_batch, - const std::string_view & state_path, + std::string_view state_path, bool save_state); // replays the last token after loading state to regenerate logits
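Taken together, the series ends with two small helpers: common_prompt_batch_decode, which decodes the prompt and optionally saves the session state just before the last token, and common_replay_last_token, which regenerates logits after a state load. A rough usage sketch under the final signatures is shown below; the function run_with_session and the file name prompt_state.bin are placeholders, the vector element type is assumed to be llama_token, and in the real examples the reload happens in a freshly created context.

```cpp
#include "common.h"
#include "llama.h"

#include <string_view>
#include <vector>

// Sketch: decode a prompt while saving the session state just before the last
// token, then later restore that state and replay the last token for logits.
static bool run_with_session(llama_context * ctx, const std::vector<llama_token> & prompt, int n_batch) {
    constexpr std::string_view state_file = "prompt_state.bin"; // placeholder; a literal, so .data() is null-terminated
    int n_past = 0;

    // first run: decode the prompt; the state file covers everything except the last token
    if (!common_prompt_batch_decode(ctx, prompt, n_past, n_batch, state_file, /*save_state=*/true)) {
        return false;
    }

    // ... sample and generate as usual ...

    // later run (in practice with a fresh context): restore the saved prefix
    std::vector<llama_token> unused(prompt.size());
    size_t n_loaded = 0;
    if (!llama_state_load_file(ctx, state_file.data(), unused.data(), unused.size(), &n_loaded)) {
        return false;
    }

    // replay the last prompt token to regenerate the logits
    n_past = (int) n_loaded;
    if (!common_replay_last_token(ctx, prompt.back(), n_past)) {
        return false;
    }
    ++n_past; // the replayed token is now part of the context

    return true;
}
```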