diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index b48373df82..3fa9253c3c 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -3092,6 +3092,70 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
+void server_context::auto_save_slots() {
+    const auto & params = impl->params_base;
+    if (params.slot_save_path.empty()) {
+        return;
+    }
+
+    for (auto & slot : impl->slots) {
+        // nothing to save for an empty slot
+        if (slot.prompt.tokens.size() == 0) {
+            continue;
+        }
+
+        const std::string model_stem = std::filesystem::path(params.model.path).stem().string();
+        // one state file per slot — a shared per-model path would make each slot
+        // overwrite the state saved by the previous one
+        const std::string filepath = params.slot_save_path + "/" + model_stem + "_slot_" + std::to_string(slot.id) + ".bin";
+
+        const llama_tokens & tokens = slot.prompt.tokens.get_text_tokens();
+        const size_t token_count = slot.prompt.tokens.size();
+        const size_t nwrite = llama_state_seq_save_file(impl->ctx, filepath.c_str(), slot.id, tokens.data(), token_count);
+
+        slot_checkpoints_save(filepath, slot.prompt.checkpoints);
+
+        SRV_INF("auto-saved slot %d (%zu tokens, %.1f MiB) to %s\n",
+                slot.id, token_count, (float) nwrite / (1024.0f * 1024.0f), filepath.c_str());
+    }
+}
+
+void server_context::auto_restore_slots() {
+    const auto & params = impl->params_base;
+    if (params.slot_save_path.empty()) {
+        return;
+    }
+
+    const std::string model_stem = std::filesystem::path(params.model.path).stem().string();
+
+    for (auto & slot : impl->slots) {
+        // per-slot state file, matching the naming used by auto_save_slots()
+        const std::string filepath = params.slot_save_path + "/" + model_stem + "_slot_" + std::to_string(slot.id) + ".bin";
+        if (!std::filesystem::exists(filepath)) {
+            continue;
+        }
+
+        llama_tokens tokens;
+        tokens.resize(slot.n_ctx);
+        size_t token_count = 0;
+        const size_t nread = llama_state_seq_load_file(impl->ctx, filepath.c_str(), slot.id, tokens.data(), tokens.size(), &token_count);
+
+        if (nread == 0) {
+            SRV_WRN("auto-restore failed for slot %d from %s\n", slot.id, filepath.c_str());
+            continue;
+        }
+
+        tokens.resize(token_count);
+        slot.prompt.tokens.clear();
+        slot.prompt.tokens.insert(tokens);
+
+        slot_checkpoints_load(filepath, slot.prompt.checkpoints);
+
+        SRV_INF("auto-restored slot %d (%zu tokens, %.1f MiB) from %s\n",
+                slot.id, token_count, (float) nread / (1024.0f * 1024.0f), filepath.c_str());
+    }
+}
+
 server_response_reader server_context::get_response_reader() {
     return impl->get_response_reader();
 }
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index 75f3d2de56..e63220a6ba 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -64,6 +64,11 @@ struct server_context {
     // terminate main loop (will unblock start_loop)
     void terminate();
 
+    // auto-save/restore slot state for seamless model hot-swapping in router mode
+    // requires --slot-save-path to be set
+    void auto_save_slots();
+    void auto_restore_slots();
+
     // get the underlaying llama_context, can return nullptr if sleeping
     // not thread-safe, should only be used from the main thread
     llama_context * get_llama_context() const;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0bd6fda17d..adbdc77bb8 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -273,6 +273,9 @@ int main(int argc, char ** argv) {
 
     LOG_INF("%s: model loaded\n", __func__);
 
+    // in router mode, restore previously saved slot state for this model
+    ctx_server.auto_restore_slots();
+
     shutdown_handler = [&](int) {
         // this will unblock start_loop()
         ctx_server.terminate();
@@ -318,6 +321,9 @@
     // this call blocks the main thread until queue_tasks.terminate() is called
     ctx_server.start_loop();
 
+    // in router mode, save slot state before exit so it can be restored on reload
+    ctx_server.auto_save_slots();
+
     clean_up();
     if (ctx_http.thread.joinable()) {
         ctx_http.thread.join();