diff --git a/common/arg.cpp b/common/arg.cpp
index f2675f842a..4b96c312f3 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2877,10 +2877,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_threads_http = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
+    add_opt(common_arg(
+        {"--cache-prompt"},
+        {"--no-cache-prompt"},
+        string_format("whether to enable prompt caching (default: %s)", params.cache_prompt ? "enabled" : "disabled"),
+        [](common_params & params, bool value) {
+            params.cache_prompt = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_PROMPT"));
     add_opt(common_arg(
         {"--cache-reuse"}, "N",
         string_format(
-            "min chunk size to attempt reusing from the cache via KV shifting (default: %d)\n"
+            "min chunk size to attempt reusing from the cache via KV shifting, requires prompt caching to be enabled (default: %d)\n"
             "[(card)](https://ggml.ai/f0.png)", params.n_cache_reuse
         ),
         [](common_params & params, int value) {
diff --git a/common/common.h b/common/common.h
index b3ac04c4ae..e60087dea3 100644
--- a/common/common.h
+++ b/common/common.h
@@ -476,6 +476,7 @@ struct common_params {
     int32_t timeout_write  = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
+    bool    cache_prompt   = true;         // whether to enable prompt caching

     int32_t n_ctx_checkpoints = 8;    // max number of context checkpoints per slot
     int32_t cache_ram_mib     = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index ed4f6546ea..aa4590e4ec 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -160,6 +160,7 @@ task_params server_task::params_from_json_cmpl(
     defaults.n_keep        = params_base.n_keep;
     defaults.n_predict     = params_base.n_predict;
     defaults.n_cache_reuse = params_base.n_cache_reuse;
+    defaults.cache_prompt  = params_base.cache_prompt;
     defaults.antiprompt    = params_base.antiprompt;

     // enabling this will output extra debug information in the HTTP responses from the server
@@ -169,7 +170,7 @@ task_params server_task::params_from_json_cmpl(
     params.stream          = json_value(data, "stream", false);
     auto stream_opt        = json_value(data, "stream_options", json::object());
     params.include_usage   = json_value(stream_opt, "include_usage", false);
-    params.cache_prompt    = json_value(data, "cache_prompt", true);
+    params.cache_prompt    = json_value(data, "cache_prompt", defaults.cache_prompt);
     params.return_tokens   = json_value(data, "return_tokens", false);
     params.return_progress = json_value(data, "return_progress", false);
     params.n_predict       = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
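
To illustrate the resolution order the patch introduces (not part of the diff itself): the `--cache-prompt` / `--no-cache-prompt` flag seeds the server-wide default, and a request-level `cache_prompt` field overrides it only when present in the JSON body. The sketch below is a minimal stand-alone approximation assuming nlohmann::json; the `json_value` helper here is a simplified stand-in modeled on the server's helper, not the exact implementation.

```cpp
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Simplified stand-in for the server's json_value helper: return the value
// stored under `key` if present and non-null, otherwise the fallback.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    if (body.contains(key) && !body.at(key).is_null()) {
        return body.at(key).get<T>();
    }
    return default_value;
}

int main() {
    std::cout << std::boolalpha;

    // server-wide default, as if launched with --no-cache-prompt
    bool server_default = false;

    // request without the field -> inherits the server-wide default
    json req1 = json::parse(R"({"prompt": "hello"})");
    std::cout << json_value(req1, "cache_prompt", server_default) << "\n"; // false

    // request that sets the field explicitly -> overrides the default
    json req2 = json::parse(R"({"prompt": "hello", "cache_prompt": true})");
    std::cout << json_value(req2, "cache_prompt", server_default) << "\n"; // true
}
```

In practice this means a server started with `--no-cache-prompt` disables prompt caching for all requests by default, while a client can still send `"cache_prompt": true` to opt a single request back in; before this change the per-request default was hardcoded to `true` regardless of server configuration.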