server : host-memory prompt caching (#16391)

* minor : code style

* server : fix prompt similarity calculation

* server : initial host-memory prompt caching

* cont

* server : refactor

* cont

* cont : make the server task of the slot const

* cont : minor [no ci]

* server : cache prompts and checkpoints only for completion tasks

* server : improve prompt caching logic

* cont : fix check for number of cached prompts [no ci]

* server : improve caching logic, add -cram CLI arg

* server : print prompt mismatch info

* cont : better naming [no ci]

* server : improve prompt cache loading logic

* server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : add option to disable prompt cache

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Georgi Gerganov authored 2025-10-09 18:54:51 +03:00, committed by GitHub
parent 8328fd4bae
commit d00cbea63c
10 changed files with 813 additions and 471 deletions

View File

@@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_ctx_checkpoints = value;
         }
     ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--cache-ram", "-cram"}, "N",
+        string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
+                      "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
+        [](common_params & params, int value) {
+            params.cache_ram_mib = value;
+        }
+    ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"

View File

@@ -33,8 +33,8 @@ struct common_chat_msg_content_part {
 struct common_chat_msg {
     std::string role;
     std::string content;
-    std::vector<common_chat_msg_content_part> content_parts = {};
-    std::vector<common_chat_tool_call> tool_calls = {};
+    std::vector<common_chat_msg_content_part> content_parts;
+    std::vector<common_chat_tool_call> tool_calls;
     std::string reasoning_content;
     std::string tool_name;
     std::string tool_call_id;
@@ -44,7 +44,7 @@ struct common_chat_msg {
     bool empty() const {
         return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
     }
-    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
+    void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
         for (auto i = 0u; i < tool_calls.size(); i++) {
             if (ids_cache.size() <= i) {
                 auto id = tool_calls[i].id;

View File

@@ -378,7 +378,7 @@ struct common_params {
     bool simple_io     = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true;  // insert new sequences for decoding on-the-fly
     bool no_perf       = false; // disable performance metrics
     bool ctx_shift     = false; // context shift on infinite text generation
     bool swa_full      = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified    = false; // enable unified KV cache
@@ -425,7 +425,8 @@ struct common_params {
     int32_t timeout_write     = timeout_read; // http write timeout in seconds
     int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
-    int32_t n_ctx_checkpoints = 3;            // max number of context checkpoints per slot
+    int32_t n_ctx_checkpoints = 8;            // max number of context checkpoints per slot
+    int32_t cache_ram_mib     = 8192;         // 0 = no limit, 1 = 1 MiB, etc.

     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";           // NOLINT

View File

@@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache(
            throw std::runtime_error("failed to create ggml context for kv cache");
        }

-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
-        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
+        ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);

        ggml_format_name(k, "cache_k_l%d", il);
        ggml_format_name(v, "cache_v_l%d", il);

File diff suppressed because it is too large

View File

@@ -66,8 +66,7 @@ def test_server_slots():
     assert len(res.body) == server.n_slots
     assert server.n_ctx is not None and server.n_slots is not None
     assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
-    assert "params" in res.body[0]
-    assert res.body[0]["params"]["seed"] == server.seed
+    assert "params" not in res.body[0]


 def test_load_split_model():

View File

@@ -19,8 +19,8 @@ def create_server():
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
@@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):

View File

@@ -16,7 +16,7 @@ def create_server():
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
 def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
@@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
 ])
 def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
     global server

View File

@@ -4,6 +4,12 @@ from utils import *

 server = ServerPreset.tinyllama2()

+SHORT_TEXT = """
+Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+""".strip()
+
 LONG_TEXT = """
 Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
 Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
@@ -21,19 +27,18 @@ def create_server():

 def test_ctx_shift_enabled():
-    # the prompt is 301 tokens
+    # the prompt is 226 tokens
     # the slot context is 512/2 = 256 tokens
-    # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
     # 96 tokens are generated thanks to shifting the context when it gets full
     global server
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 96,
-        "prompt": LONG_TEXT,
+        "prompt": SHORT_TEXT,
     })
     assert res.status_code == 200
-    assert res.body["timings"]["prompt_n"] == 173
+    assert res.body["timings"]["prompt_n"] == 226
     assert res.body["timings"]["predicted_n"] == 96
     assert res.body["truncated"] is True

View File

@@ -31,10 +31,10 @@
 using json = nlohmann::ordered_json;

-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

 #define SRV_INF(fmt, ...) LOG_INF("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_WRN(fmt, ...) LOG_WRN("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
@@ -1102,6 +1102,7 @@ public:
     ~server_tokens() = default;

     // Prevent copying
+    // TODO: server_tokens should be copyable - remove this:
     server_tokens(const server_tokens&) = delete;
     server_tokens& operator=(const server_tokens&) = delete;
@@ -1119,7 +1120,7 @@ public:
         }
     }

-    server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}

     // for debugging
     std::string str() const {
@@ -1144,9 +1145,8 @@ public:
         auto it = map_pos_to_media.find(pos);
         if (it != map_pos_to_media.end()) {
             return it->second;
-        } else {
-            throw std::runtime_error("Chunk not found");
         }
+        throw std::runtime_error("Chunk not found");
     }

     void push_back(llama_token tok) {
@@ -1170,7 +1170,7 @@ public:
             map_pos_to_media[start_pos] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
-            auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+            const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
             for (size_t i = 0; i < n_tokens; ++i) {
                 push_back(text_tokens[i]);
             }
@@ -1190,7 +1190,7 @@ public:
         // We could also just check, but this will prevent silently dropping MTMD data.
         GGML_ASSERT(has_mtmd);
         for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-            auto chunk = tokens.map_pos_to_media[it->first].get();
+            auto * chunk = tokens.map_pos_to_media[it->first].get();
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
             map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
         }
@@ -1271,33 +1271,52 @@ public:
     }

     size_t get_common_prefix(const server_tokens & b) const {
-        size_t max_idx = std::min(tokens.size(), b.tokens.size());
-        for (size_t i = 0; i < max_idx; ++i) {
-            auto & ai = tokens[i];
-            auto & bi = b.tokens[i];
-            if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
-                GGML_ASSERT(has_mtmd);
-                const auto & a_chunk = find_chunk(i);
-                const auto & b_chunk = b.find_chunk(i);
-                GGML_ASSERT(a_chunk && b_chunk);
-                std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
-                std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
-                size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
-                if (ai_id == bi_id && a_pos == b_pos) {
-                    GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
-                    i += a_pos - 1; // will be +1 by the for loop
-                    continue;
-                } else {
-                    return i;
-                }
-            } else if (ai == bi) {
-                continue;
-            } else {
-                return i;
-            }
-        }
+        const size_t max_idx = std::min(tokens.size(), b.tokens.size());
+
+        if (!has_mtmd) {
+            for (size_t i = 0; i < max_idx; ++i) {
+                if (tokens[i] == b.tokens[i]) {
+                    continue;
+                }
+
+                return i;
+            }
+
+            return max_idx;
+        }
+
+        for (size_t i = 0; i < max_idx; ++i) {
+            const llama_token ai = tokens[i];
+            const llama_token bi = b.tokens[i];
+
+            if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
+                const auto & a_chunk = find_chunk(i);
+                const auto & b_chunk = b.find_chunk(i);
+
+                GGML_ASSERT(a_chunk && b_chunk);
+
+                const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
+                const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
+
+                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
+                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+
+                if (id_ai == id_bi && pos_a == pos_b) {
+                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
+                    i += pos_a - 1; // will be +1 by the for loop
+                    continue;
+                }
+
+                return i;
+            }
+
+            if (ai == bi) {
+                continue;
+            }
+
+            return i;
+        }
+
         return max_idx; // all tokens are equal
     }
@@ -1308,7 +1327,7 @@ public:
         const int32_t n_vocab = llama_vocab_n_tokens(vocab);

         for (size_t i = 0; i < tokens.size(); ++i) {
-            auto & t = tokens[i];
+            const auto & t = tokens[i];
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
@@ -1330,8 +1349,8 @@ public:
             mtmd_context * mctx,
             llama_pos n_past,
             int32_t seq_id,
-            llama_pos & n_pos_out) {
-        auto & chunk = find_chunk(n_past);
+            llama_pos & n_pos_out) const {
+        const auto & chunk = find_chunk(n_past);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);