diff --git a/common/common.cpp b/common/common.cpp
index 55e5cb3372..b1382e0644 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1804,7 +1804,7 @@ bool common_prompt_batch_decode(
     const std::vector<llama_token> & tokens,
     int & n_past,
     int n_batch,
-    const std::filesystem::path & state_path,
+    std::string_view state_path,
     bool save_state,
     bool is_last_batch) {
     const int n_eval = tokens.size();
@@ -1828,8 +1828,8 @@ bool common_prompt_batch_decode(
         }
         n_past += n_tokens_before_last;
 
-        llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last);
-        LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last);
+        llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
+        LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
 
         llama_token last_token = tokens.back();
         llama_batch batch = llama_batch_get_one(&last_token, 1);
diff --git a/common/common.h b/common/common.h
index b1124d92b6..5b3e5bdc5b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -5,7 +5,6 @@
 #include "ggml-opt.h"
 #include "llama-cpp.h"
 
-#include <filesystem>
 #include <set>
 #include <string>
 #include <string_view>
@@ -790,7 +789,7 @@ bool common_prompt_batch_decode(
     const std::vector<llama_token> & embd,
     int & n_past,
     int n_batch,
-    const std::filesystem::path & state_path,
+    std::string_view state_path,
     bool save_state,
     bool is_last_batch = true);
 
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index a949e1b643..5e35dcd603 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -2,7 +2,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <filesystem>
 #include <vector>
 #include <cstdio>
 
@@ -13,7 +12,7 @@ int main(int argc, char ** argv) {
     params.prompt = "The quick brown fox";
     params.sampling.seed = 1234;
 
-    std::filesystem::path state_file = "dump_state.bin";
+    const std::string_view state_file = "dump_state.bin";
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
         return 1;
@@ -100,7 +99,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
     size_t n_token_count_out = 0;
 
-    if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
+    if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
         fprintf(stderr, "\n%s : failed to load state\n", __func__);
         return 1;
     }
@@ -154,7 +153,7 @@ int main(int argc, char ** argv) {
     // load state (rng, logits, embedding and kv_cache) from file
     n_token_count_out = 0;
 
-    if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
+    if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
         fprintf(stderr, "\n%s : failed to load state\n", __func__);
         return 1;
     }