server : support multi-modal context checkpoints (#19849)

* Modify llama-memory-hybrid-iswa.cpp

* Modify llama-memory-recurrent.cpp

* Modify server-common.cpp

* Modify server-common.h

* Modify server-context.cpp

* Modify server-task.h

* Added comment to llama-memory-hybrid-iswa.cpp

* Remove comment from server-context.cpp

* Stylistic fix server-context.cpp

* Fix an issue when seqrm isn't called in server-context.cpp

* cont : alternative impl

* cont : cleanup

* cont : n_tokens -> int64_t

---------

Co-authored-by: timkhronos <timkhronos@gmail.com>
This commit is contained in:
Georgi Gerganov 2026-02-25 15:14:27 +02:00 committed by GitHub
parent c747294b2d
commit d7d826b3c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 100 additions and 35 deletions

View File

@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
const auto & cell = cells[tail_id];
// partial intersection is invalid if it includes the final pos
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
return false;
}
// invalidate tails which will be cleared

View File

@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
}
llama_pos server_tokens::pos_next() const {
llama_pos server_tokens::pos_next(int64_t n_tokens) const {
if (!has_mtmd) {
return tokens.size();
if (n_tokens < 0) {
return tokens.size();
}
return n_tokens;
}
llama_pos res = tokens.size();
if (n_tokens < 0) {
llama_pos res = tokens.size();
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
}
return res;
}
return res;
int64_t idx = 0;
llama_pos pos = 0;
GGML_ASSERT(n_tokens <= (int64_t)tokens.size());
while (idx < n_tokens) {
const auto media_it = map_idx_to_media.find(idx);
if (media_it != map_idx_to_media.end()) {
const auto & chunk = media_it->second;
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
pos += n_pos;
idx += n_tok;
} else {
pos++;
idx++;
}
}
return pos;
}
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
if (!has_mtmd) {
return std::min((size_t)(max_pos + 1), tokens.size());
}
size_t idx = 0;
llama_pos pos = 0;
while (idx < tokens.size()) {
const auto media_it = map_idx_to_media.find(idx);
if (media_it != map_idx_to_media.end()) {
const auto & chunk = media_it->second;
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
pos += n_pos;
idx += n_tok;
} else {
pos++;
idx++;
}
if (pos > max_pos) {
break;
}
}
return idx;
}
std::string server_tokens::str() const {

View File

@ -167,7 +167,12 @@ public:
// for debugging
std::string str() const;
llama_pos pos_next() const;
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
llama_pos pos_next(int64_t n_tokens = -1) const;
// number of tokens with position <= max_pos
size_t size_up_to_pos(llama_pos max_pos) const;
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
void push_back(llama_token tok);

View File

@ -1442,7 +1442,7 @@ private:
res->id = slot.task->id;
res->id_slot = slot.id;
res->index = slot.task->index;
res->index = slot.task->index;
// keep copy of last generated text for debugging purposes
if (slots_debug) {
@ -2282,15 +2282,15 @@ private:
n_past = 0;
}
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, n_past - n_swa);
const auto pos_min_thold = std::max(0, pos_next - n_swa);
// note: disallow with mtmd contexts for now
// https://github.com/ggml-org/llama.cpp/issues/17043
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
@ -2341,9 +2341,6 @@ private:
}
if (pos_min > pos_min_thold) {
// TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
@ -2364,18 +2361,20 @@ private:
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = slot.prompt.tokens.size_up_to_pos(pos_next);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
pos_next = 0;
n_past = 0;
}
}
@ -2386,7 +2385,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@ -2402,7 +2401,7 @@ private:
SLT_WRN(slot, "n_past was set to %d\n", n_past);
}
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
@ -2520,10 +2519,6 @@ private:
}
}
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@ -2536,8 +2531,6 @@ private:
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
slot.init_sampler();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@ -2549,13 +2542,15 @@ private:
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
@ -2563,16 +2558,21 @@ private:
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
} else {
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
}

View File

@ -557,6 +557,8 @@ struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
int64_t n_tokens;
std::vector<uint8_t> data;
size_t size() const {