From 6051df2f2b502ba07a45bdd7be79831a0d9e8af8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Mar 2026 20:21:53 +0200 Subject: [PATCH] server : improve mtmd ctx checkpoints --- tools/server/server-context.cpp | 73 ++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 2e43bebd6f..8740975544 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2459,8 +2459,39 @@ private: slot.n_prompt_tokens_cache = 0; } + // If using an alora, there may be uncached tokens that come + // before the invocation sequence. When this happens, the + // tokens before the invocation sequence need to be + // processed without the adapter in a separate batch, then + // the adapter needs to be enabled for the remaining tokens. + if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { + SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + const auto & enabled_loras = lora_get_enabled_ids(slot.lora); + GGML_ASSERT(enabled_loras.size() == 1); + alora_scale = slot.lora[enabled_loras[0]].scale; + slot.lora[enabled_loras[0]].scale = 0.0f; + alora_disabled_id = enabled_loras[0]; + } + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + + bool has_mtmd = false; + // check if we should process the image if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image @@ -2481,38 +2512,9 @@ private: slot.prompt.tokens.push_back(chunk.get()); // copy } - do_checkpoint = false; // do not checkpoint right after an image chunk + has_mtmd = true; } - // If using an alora, there may be uncached tokens that come - // before the invocation sequence. When this happens, the - // tokens before the invocation sequence need to be - // processed without the adapter in a separate batch, then - // the adapter needs to be enabled for the remaining tokens. - if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { - SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - const auto & enabled_loras = lora_get_enabled_ids(slot.lora); - GGML_ASSERT(enabled_loras.size() == 1); - alora_scale = slot.lora[enabled_loras[0]].scale; - slot.lora[enabled_loras[0]].scale = 0.0f; - alora_disabled_id = enabled_loras[0]; - } - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. - // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); - // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { // get next token to process @@ -2544,13 +2546,13 @@ private: // - 4 + n_ubatch // - 4 // ref: https://github.com/ggml-org/llama.cpp/pull/20288 - { + if (do_checkpoint) { static const int checkpoint_offsets[] = {4 + n_ubatch, 4}; bool should_break = false; for (int offset : checkpoint_offsets) { const int n_last = std::min(n_batch, offset); - if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) { + if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) { should_break = true; break; } @@ -2607,10 +2609,13 @@ private: const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); + + // do not checkpoint after mtmd chunks + do_checkpoint = do_checkpoint && !has_mtmd; // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64); // note: we create the checkpoint before calling llama_decode(), so the current batch is not // yet processed and therefore it is not part of the checkpoint.