From 61693b5bda8f67181a0e6aa7f633bb956fdacb26 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 5 Feb 2026 09:40:07 +0100 Subject: [PATCH] update save-state.cpp to use llama_state_load_file This commit updates the save-load-state example to utilize the new llama_state_load_file function for loading the model state from a file. And it also replays the last token after loading since this state is now stored before the last token is processed. I'm not sure if this is acceptable or not, as it does change the example to not directly use llama_state_get_data and llama_state_set_data for loading which might have been the point of the example. --- examples/save-load-state/save-load-state.cpp | 99 +++++++++----------- 1 file changed, 42 insertions(+), 57 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 39d4464663..8b111a2ca8 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -2,15 +2,30 @@ #include "common.h" #include "llama.h" +#include #include #include +static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) { + llama_batch batch = llama_batch_get_one(&last_token, 1); + int pos = n_past; + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__); + return false; + } + ++n_past; + return true; +} + int main(int argc, char ** argv) { common_params params; params.prompt = "The quick brown fox"; params.sampling.seed = 1234; + std::filesystem::path state_file = "dump_state.bin"; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } @@ -53,35 +68,16 @@ int main(int argc, char ** argv) { // tokenize prompt auto tokens = common_tokenize(ctx, params.prompt, true); - // prepare the batch - llama_batch batch = llama_batch_init(tokens.size(), 0, 1); - for (size_t i = 0; i < tokens.size(); i++) { - common_batch_add(batch, tokens[i], i, {0}, false); + const bool save_state = true; + if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) { + return 1; } - batch.logits[batch.n_tokens - 1] = true; // generate next token - - // evaluate prompt - llama_decode(ctx, batch); - n_past += batch.n_tokens; - - // save state (rng, logits, embedding and kv_cache) to file - { - std::vector state_mem(llama_state_get_size(ctx)); - const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size()); - - FILE *fp_write = fopen("dump_state.bin", "wb"); - fwrite(state_mem.data(), 1, written, fp_write); - fclose(fp_write); - - fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size()); - } - - // save state (last tokens) - const auto n_past_saved = n_past; // first run printf("\nfirst run: %s", params.prompt.c_str()); + llama_batch batch = llama_batch_init(1, 0, 1); + for (auto i = 0; i < params.n_predict; i++) { auto next_token = llama_sampler_sample(smpl, ctx, -1); auto next_token_str = common_token_to_piece(ctx, next_token); @@ -111,27 +107,22 @@ int main(int argc, char ** argv) { printf("\nsecond run: %s", params.prompt.c_str()); - // load state (rng, logits, embedding and kv_cache) from file - { - std::vector state_mem; + // load state from file + std::vector unused_sts(tokens.size()); // unused session tokens. + size_t n_token_count_out = 0; - FILE * fp_read = fopen("dump_state.bin", "rb"); - fseek(fp_read, 0, SEEK_END); - state_mem.resize(ftell(fp_read)); - fseek(fp_read, 0, SEEK_SET); - const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); - fclose(fp_read); - - if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) { - fprintf(stderr, "\n%s : failed to read state\n", __func__); - return 1; - } - - fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + fprintf(stderr, "\n%s : failed to load state\n", __func__); + return 1; } + fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out); + // restore state (last tokens) - n_past = n_past_saved; + n_past = n_token_count_out; + if (!replay_last_token(ctx2, tokens.back(), n_past)) { + return 1; + } // second run for (auto i = 0; i < params.n_predict; i++) { @@ -169,26 +160,20 @@ int main(int argc, char ** argv) { printf("\nsingle seq run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file - { - std::vector state_mem; + n_token_count_out = 0; - FILE * fp_read = fopen("dump_state.bin", "rb"); - fseek(fp_read, 0, SEEK_END); - state_mem.resize(ftell(fp_read)); - fseek(fp_read, 0, SEEK_SET); - const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); - fclose(fp_read); - - if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) { - fprintf(stderr, "\n%s : failed to read state\n", __func__); - return 1; - } - - fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + fprintf(stderr, "\n%s : failed to load state\n", __func__); + return 1; } + fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out); + // restore state (last tokens) - n_past = n_past_saved; + n_past = n_token_count_out; + if (!replay_last_token(ctx3, tokens.back(), n_past)) { + return 1; + } // save seq 0 and load into seq 1 {