diff --git a/common/arg.cpp b/common/arg.cpp index cd73d96420..0d8561dbb3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1279,13 +1279,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_SWA_FULL")); add_opt(common_arg( - {"--ctx-checkpoints", "--swa-checkpoints"}, "N", + {"-ctxcp", "--ctx-checkpoints", "--swa-checkpoints"}, "N", string_format("max number of context checkpoints to create per slot (default: %d)" "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints), [](common_params & params, int value) { params.n_ctx_checkpoints = value; } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"-cpent", "--checkpoint-every-n-tokens"}, "N", + string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt), + [](common_params & params, int value) { + params.checkpoint_every_nt = value; + } + ).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( {"-cram", "--cache-ram"}, "N", string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)" diff --git a/common/common.h b/common/common.h index 3c09cdf040..3e1b23f5d4 100644 --- a/common/common.h +++ b/common/common.h @@ -516,14 +516,15 @@ struct common_params { std::string cls_sep = "\t"; // separator of classification sequences // server params - int32_t port = 8080; // server listens on this network port - int32_t timeout_read = 600; // http read timeout in seconds - int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) - int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - bool cache_prompt = true; // whether to enable prompt caching - int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot - int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + bool cache_prompt = true; // whether to enable prompt caching + int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot + int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill + int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aafed49502..9dbd6d798a 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -12,6 +12,7 @@ #include "mtmd.h" #include "mtmd-helper.h" +#include #include #include #include @@ -2348,8 +2349,10 @@ private: const auto it = std::find_if( slot.prompt.checkpoints.rbegin(), slot.prompt.checkpoints.rend(), - [&](const auto & cur) { + [&, func_name = __func__](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12, + func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold); return cur.pos_min < pos_min_thold; } ); @@ -2533,48 +2536,66 @@ private: slot.i_batch = batch.n_tokens - 1; slot.init_sampler(); - - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); - - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); - - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); - - // note: we create the checkpoint before calling llama_decode(), so the current batch is not - // yet processed and therefore it is not part of the checkpoint. - if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - } - SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); } else { + // only do non-end checkpoints if the "checkpoint every n tokens" option is set + do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0; + if (do_checkpoint) { + llama_pos last_checkpoint = 0; + if (!slot.prompt.checkpoints.empty()) { + last_checkpoint = slot.prompt.checkpoints.back().n_tokens; + } + do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt; + if (do_checkpoint) { + SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens()); + } + } SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } + + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + + // note: we create the checkpoint before calling llama_decode(), so the current batch is not + // yet processed and therefore it is not part of the checkpoint. + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, + "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 + ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = + llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens, + /*.data = */ std::vector(checkpoint_size), + }); + + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_WRN(slot, + "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 + ", size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, + cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + } } if (!slot_batched) {