diff --git a/tools/server/README.md b/tools/server/README.md index bf274db79d..9deb241b07 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -495,6 +495,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re `n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries. +`n_cache_reuse`: Min chunk size to attempt reusing from the cache via KV shifting. For more info, see `--cache-reuse` arg. Default: `0`, which is disabled. + `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`. `stop`: Specify a JSON array of stopping strings. diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 12a4e94e5d..d0039631d4 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1880,8 +1880,18 @@ struct server_context_impl { n_past = std::min(n_past, slot.alora_invocation_start - 1); } + const auto n_cache_reuse = slot.task->params.n_cache_reuse; + + const bool can_cache_reuse = + llama_memory_can_shift(llama_get_memory(ctx)) && + !slot.prompt.tokens.has_mtmd; + + if (!can_cache_reuse && n_cache_reuse > 0) { + SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse); + } + // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { + if (can_cache_reuse && n_cache_reuse > 0) { GGML_ASSERT(!slot.prompt.tokens.has_mtmd); size_t head_c = n_past; // cache @@ -1892,7 +1902,7 @@ struct server_context_impl { GGML_ABORT("not supported by multimodal"); } - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past); while (head_c < slot.prompt.tokens.size() && head_p < input_tokens.size()) { @@ -1901,11 +1911,10 @@ struct server_context_impl { while (head_c + n_match < slot.prompt.tokens.size() && head_p + n_match < input_tokens.size() && slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - n_match++; } - if (n_match >= (size_t) params_base.n_cache_reuse) { + if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index c401f47a78..360826062b 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl( // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) task_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.antiprompt = params_base.antiprompt; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.n_cache_reuse = params_base.n_cache_reuse; + defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl( params.n_keep = json_value(data, "n_keep", defaults.n_keep); params.n_discard = json_value(data, "n_discard", defaults.n_discard); params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); + params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 4e4840fc83..da4e22a7cd 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -55,6 +55,8 @@ struct task_params { int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters int32_t n_cmpl = 1; // number of completions to generate from this prompt + int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) + int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -62,18 +64,19 @@ struct task_params { std::vector antiprompt; std::vector response_fields; - bool timings_per_token = false; + + bool timings_per_token = false; bool post_sampling_probs = false; struct common_params_sampling sampling; struct common_params_speculative speculative; // response formatting - bool verbose = false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)