server : make cache_reuse configurable per request (#17858)
parent 5814b4dce1
commit 2bc96931d2
@@ -495,6 +495,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re
 `n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries.

+`n_cache_reuse`: Min chunk size to attempt reusing from the cache via KV shifting. For more info, see `--cache-reuse` arg. Default: `0`, which is disabled.
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

 `stop`: Specify a JSON array of stopping strings.

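For illustration, below is a minimal client-side sketch of the new per-request override. It uses nlohmann/json (the JSON library the server itself uses) only to build and print the body that a client would POST to the `/completion` endpoint; the prompt text and the value `256` are made up for the example.

```cpp
// Sketch: build a /completion request body that overrides the server's
// --cache-reuse default for this request only (assumes nlohmann/json is available).
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json body = {
        {"prompt", "Write a haiku about KV caches."},
        {"stream", false},
        // attempt cache reuse for matching chunks of at least 256 tokens
        {"n_cache_reuse", 256},
    };

    // this JSON would be sent as the POST body to the server's /completion endpoint
    std::cout << body.dump(2) << std::endl;
    return 0;
}
```

With this change, omitting the field falls back to the server's `--cache-reuse` setting, while an explicit `"n_cache_reuse": 0` disables chunk reuse for that request.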
@@ -1880,8 +1880,18 @@ struct server_context_impl {
         n_past = std::min(n_past, slot.alora_invocation_start - 1);
     }

+    const auto n_cache_reuse = slot.task->params.n_cache_reuse;
+
+    const bool can_cache_reuse =
+        llama_memory_can_shift(llama_get_memory(ctx)) &&
+        !slot.prompt.tokens.has_mtmd;
+
+    if (!can_cache_reuse && n_cache_reuse > 0) {
+        SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
+    }
+
     // reuse chunks from the cached prompt by shifting their KV cache in the new position
-    if (params_base.n_cache_reuse > 0) {
+    if (can_cache_reuse && n_cache_reuse > 0) {
         GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

         size_t head_c = n_past; // cache
@@ -1892,7 +1902,7 @@ struct server_context_impl {
             GGML_ABORT("not supported by multimodal");
         }

-        SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
+        SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);

         while (head_c < slot.prompt.tokens.size() &&
                head_p < input_tokens.size()) {
@@ -1901,11 +1911,10 @@ struct server_context_impl {
             while (head_c + n_match < slot.prompt.tokens.size() &&
                    head_p + n_match < input_tokens.size() &&
                    slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
-
                 n_match++;
             }

-            if (n_match >= (size_t) params_base.n_cache_reuse) {
+            if (n_match >= (size_t) n_cache_reuse) {
                 SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                 //for (size_t i = head_p; i < head_p + n_match; i++) {
                 //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
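The hunks above gate the existing reuse loop on the per-request value and on whether the KV cache supports shifting at all. As a reading aid, here is a simplified, self-contained sketch of the greedy chunk matching that loop performs; it works on plain token vectors and only reports which spans would be shifted. The advancement rule for a too-short match (`head_c += 1`) is outside the hunks shown and is an assumption here, and the real server additionally shifts the corresponding KV-cache cells and advances `n_past` for each reused chunk.

```cpp
// Simplified sketch of the greedy chunk matching used for cache reuse.
// cached = tokens already in the slot's KV cache (previous prompt)
// target = tokens of the new prompt
// Matching chunks of at least min_chunk tokens are reported as
// (cache offset, prompt offset, length) triples.
#include <cstddef>
#include <cstdio>
#include <vector>

struct reuse_chunk {
    size_t head_c; // start in the cached tokens
    size_t head_p; // start in the new prompt
    size_t n;      // chunk length
};

std::vector<reuse_chunk> find_reusable_chunks(
        const std::vector<int> & cached,
        const std::vector<int> & target,
        size_t n_past,      // shared prefix already matched exactly
        size_t min_chunk) { // plays the role of n_cache_reuse
    std::vector<reuse_chunk> out;

    size_t head_c = n_past; // cursor into the cached tokens
    size_t head_p = n_past; // cursor into the new prompt

    while (head_c < cached.size() && head_p < target.size()) {
        // extend the match as far as possible from the current cursors
        size_t n_match = 0;
        while (head_c + n_match < cached.size() &&
               head_p + n_match < target.size() &&
               cached[head_c + n_match] == target[head_p + n_match]) {
            n_match++;
        }

        if (n_match >= min_chunk) {
            out.push_back({head_c, head_p, n_match});
            head_c += n_match;
            head_p += n_match;
        } else {
            // no usable chunk here - skip one cached token and try again (assumed rule)
            head_c += 1;
        }
    }

    return out;
}

int main() {
    // the new prompt drops three tokens from the middle of the cached prompt,
    // so the tail of the old KV cache can be shifted left and reused
    const std::vector<int> cached = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
    const std::vector<int> target = {10, 11, 15, 16, 17, 18, 19};

    for (const auto & c : find_reusable_chunks(cached, target, /*n_past=*/2, /*min_chunk=*/3)) {
        std::printf("reuse %zu tokens: cache [%zu, %zu) -> prompt [%zu, %zu)\n",
                    c.n, c.head_c, c.head_c + c.n, c.head_p, c.head_p + c.n);
    }
    return 0;
}
```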
@@ -159,6 +159,7 @@ task_params server_task::params_from_json_cmpl(
     defaults.speculative = params_base.speculative;
     defaults.n_keep = params_base.n_keep;
     defaults.n_predict = params_base.n_predict;
+    defaults.n_cache_reuse = params_base.n_cache_reuse;
     defaults.antiprompt = params_base.antiprompt;

     // enabling this will output extra debug information in the HTTP responses from the server
@@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl(
     params.n_keep = json_value(data, "n_keep", defaults.n_keep);
     params.n_discard = json_value(data, "n_discard", defaults.n_discard);
     params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
+    params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse);
     //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
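The `json_value(data, "n_cache_reuse", defaults.n_cache_reuse)` line above is what makes the setting per-request: the server-wide `--cache-reuse` value seeds the default, and a request body carrying `n_cache_reuse` overrides it for that task only. Below is a rough standalone illustration of that precedence; the `json_value` helper here is a simplified stand-in, not the server's exact implementation, and the numeric values are made up.

```cpp
// Sketch of the default-vs-override precedence for n_cache_reuse.
// json_value below returns the field from the request body if present,
// otherwise the supplied default (simplified stand-in for the server helper).
#include <nlohmann/json.hpp>
#include <cstdint>
#include <iostream>

template <typename T>
static T json_value(const nlohmann::json & body, const char * key, const T & default_value) {
    return body.contains(key) ? body.at(key).get<T>() : default_value;
}

int main() {
    const int32_t n_cache_reuse_cli = 256; // e.g. server started with --cache-reuse 256

    const nlohmann::json req_a = {{"prompt", "hello"}};                         // no override
    const nlohmann::json req_b = {{"prompt", "hello"}, {"n_cache_reuse", 0}};   // disable for this request
    const nlohmann::json req_c = {{"prompt", "hello"}, {"n_cache_reuse", 512}}; // raise the threshold

    std::cout << json_value(req_a, "n_cache_reuse", n_cache_reuse_cli) << "\n"; // 256
    std::cout << json_value(req_b, "n_cache_reuse", n_cache_reuse_cli) << "\n"; // 0
    std::cout << json_value(req_c, "n_cache_reuse", n_cache_reuse_cli) << "\n"; // 512
    return 0;
}
```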
@@ -55,6 +55,8 @@ struct task_params {
     int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
     int32_t n_cmpl = 1; // number of completions to generate from this prompt

+    int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
+
     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

@@ -62,6 +64,7 @@ struct task_params {

     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;

     bool timings_per_token = false;
     bool post_sampling_probs = false;