diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 3541d910d8..b86e7e608e 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2530,9 +2530,24 @@ private:
                     slot.n_prompt_tokens_processed++;
 
                     // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                    const int n_last = std::min(n_batch, 512);
-                    if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
-                        break;
+                    // create checkpoints this many tokens before the end of the prompt:
+                    // - 4 + n_ubatch
+                    // - 4
+                    // ref: https://github.com/ggml-org/llama.cpp/pull/20288
+                    {
+                        static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
+
+                        bool should_break = false;
+                        for (int offset : checkpoint_offsets) {
+                            const int n_last = std::min(n_batch, offset);
+                            if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
+                                should_break = true;
+                                break;
+                            }
+                        }
+                        if (should_break) {
+                            break;
+                        }
                     }
                 }
 
@@ -2554,18 +2569,27 @@ private:
                     slot.init_sampler();
 
                     SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
                 } else {
-                    // only do non-end checkpoints if the "checkpoint every n tokens" option is set
-                    do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
-                    if (do_checkpoint) {
-                        llama_pos last_checkpoint = 0;
-                        if (!slot.prompt.checkpoints.empty()) {
-                            last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
-                        }
-                        do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
-                        SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+                    if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
+                        // near the end of the prompt
+                        do_checkpoint = do_checkpoint && true;
+                    } else {
+                        // only do non-end checkpoints if the "checkpoint every n tokens" option is set
+                        do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
+                        if (do_checkpoint) {
+                            llama_pos last_checkpoint = 0;
+                            if (!slot.prompt.checkpoints.empty()) {
+                                last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
+                            }
+
+                            do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+
+                            if (do_checkpoint) {
+                                SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+                            }
                         }
                     }
 
+                    SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
                 }
 
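Note (reviewer aid, not part of the patch): the two hunks interact — the first pauses prompt processing exactly `4 + n_ubatch` and `4` tokens before the end of the prompt, and the second then allows a checkpoint whenever the remaining prompt fits within one ubatch, instead of requiring the `checkpoint_every_nt` interval. Below is a minimal, self-contained C++ sketch of the position check from the first hunk; `should_checkpoint_now`, `n_tokens_total`, and `n_tokens_done` are stand-in names for this illustration (not the server's `slot`/`task` state), the values in `main` are made up, and the sketch assumes `do_checkpoint` is already true, which the real code checks alongside the position.

```cpp
// Sketch of the two-offset checkpoint trigger from the first hunk.
// Stand-in names: n_tokens_total ~ slot.task->n_tokens(),
//                 n_tokens_done  ~ slot.prompt.n_tokens().
#include <algorithm>
#include <cstdio>

// True when prompt processing should pause so a checkpoint can be created,
// i.e. exactly (4 + n_ubatch) or 4 tokens remain (each capped by n_batch).
static bool should_checkpoint_now(int n_tokens_total, int n_tokens_done,
                                  int n_batch, int n_ubatch) {
    const int checkpoint_offsets[] = {4 + n_ubatch, 4};

    for (int offset : checkpoint_offsets) {
        const int n_last = std::min(n_batch, offset);
        if (n_tokens_total == n_tokens_done + n_last) {
            return true;
        }
    }
    return false;
}

int main() {
    const int n_batch  = 2048; // hypothetical batch size
    const int n_ubatch = 512;  // hypothetical micro-batch size
    const int n_prompt = 6000; // hypothetical total prompt tokens

    // Walk through the prompt and report where processing would break
    // to allow a checkpoint.
    for (int done = 0; done < n_prompt; done++) {
        if (should_checkpoint_now(n_prompt, done, n_batch, n_ubatch)) {
            std::printf("checkpoint break after %d tokens (%d remaining)\n",
                        done, n_prompt - done);
        }
    }
    return 0;
}
```

With these values the breaks land after 5484 and 5996 tokens, i.e. `4 + n_ubatch = 516` and `4` tokens before the end of the prompt, which is where the patch intends checkpoints to be created.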