From 8a4fe643107cbcc7751331b8bc3d4d982bb46572 Mon Sep 17 00:00:00 2001
From: Sascha Rogmann
Date: Mon, 9 Feb 2026 22:29:12 +0100
Subject: [PATCH] server : speculative decoding using checkpoints

---
 common/arg.cpp                  |  10 +++
 common/common.h                 |   1 +
 common/ngram-map.cpp            |   4 +-
 tools/server/server-context.cpp | 122 ++++++++++++++++++++++++++++++--
 4 files changed, 128 insertions(+), 9 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 9c85696ebd..becc0f7c47 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3447,6 +3447,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.ngram_min_hits = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ckpt-num-tries"}, "N",
+        string_format("number of tries for speculative decoding with recurrent memory (default: %d)", params.speculative.ckpt_num_tries),
+        [](common_params & params, int value) {
+            if (value < 0 || value > 10) {
+                throw std::invalid_argument("number of tries must be between 0 and 10 inclusive");
+            }
+            params.speculative.ckpt_num_tries = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-ctkd", "--cache-type-k-draft"}, "TYPE",
         string_format(
diff --git a/common/common.h b/common/common.h
index b284244530..158037d399 100644
--- a/common/common.h
+++ b/common/common.h
@@ -270,6 +270,7 @@ struct common_params_speculative {
     uint16_t ngram_size_n   = 12; // ngram size for lookup
     uint16_t ngram_size_m   = 48; // mgram size for speculative tokens
     uint16_t ngram_min_hits =  1; // minimum hits at ngram/mgram lookup for mgram to be proposed
+    uint16_t ckpt_num_tries =  0; // number of tries in case of recurrent memory

    std::shared_ptr ngram_mod;

diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp
index 2b876a6e99..b289abbd04 100644
--- a/common/ngram-map.cpp
+++ b/common/ngram-map.cpp
@@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
         GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
     }

-    if (map.idx_last_check > cur_len) {
+    if (map.idx_last_check > cur_len) { // Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); } @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map, LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); - map.last_draft_created = false; + map.last_draft_created = true; map.last_draft_key_idx = key_offset; map.last_draft_value_idx = 0; // value 0 is used for simple mode return; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ceafcac179..8f318c4a4c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -146,6 +146,13 @@ struct server_slot { llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; + // use of checkpoints in speculative mode + uint16_t spec_n_denials = 0; // number of drafts not accepted at the current position + int spec_n_accepted = 0; // number of accepted tokens at current position + bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created + size_t spec_ckpt_size_part; // size of partial checkpoint + + // stats size_t n_sent_text = 0; // number of sent text character @@ -742,7 +749,7 @@ private: const bool can_spec = common_speculative_is_compat(ctx); if (!can_spec) { - SRV_WRN("%s", "speculative decoding not supported by this context\n"); + SRV_WRN("%s", "speculative decoding not supported by this context without checkpoints\n"); } // initialize slots @@ -757,7 +764,7 @@ private: slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (can_spec) { + if (can_spec || params_base.speculative.ckpt_num_tries > 0) { slot.spec = common_speculative_init(params_base.speculative, slot.ctx); if (slot.spec) { if (mctx) { @@ -2041,8 +2048,8 @@ private: // generate draft tokens in speculative decoding mode // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] // perform the speculative drafting for all sequences at the same time in a single batch - const int n_draft_max = slot.get_n_draft_max(); - if (n_draft_max > 0) { + const int n_draft_max = (slot.spec_n_accepted > 0) ? slot.spec_n_accepted : slot.get_n_draft_max(); + if (n_draft_max > 0 && slot.spec_n_denials < params_base.speculative.ckpt_num_tries) { if (mctx) { // we should never reach this, as speculative is automatically disabled if mmproj is loaded GGML_ABORT("not supported by multimodal"); @@ -2059,17 +2066,67 @@ private: draft.resize(n_draft_max); } + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + bool do_checkpoint = !draft.empty() && slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints; + SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n", + draft.size(), slot.spec_n_denials, + slot.prompt.checkpoints.size(), + do_checkpoint ? 
"yes" : "no", pos_min, pos_max, + cached_text_tokens[cached_text_tokens.size() - 3], + cached_text_tokens[cached_text_tokens.size() - 2], + cached_text_tokens[cached_text_tokens.size() - 1]); + + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, 0); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.data = */ std::vector(checkpoint_size), + }); + + const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.spec_ckpt_size_part = n; + slot.spec_has_ckpt = true; + } + // add the sampled token to the batch slot.i_batch_dft.push_back(batch.n_tokens); + SLT_DBG(slot, "before common_batch_add: sampled=%d, pos_next=%d, tokens.size=%zu, tokens.last=%d\n", + slot.sampled, slot.prompt.tokens.pos_next(), slot.prompt.tokens.size(), slot.prompt.tokens[slot.prompt.tokens.size() -1]); common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); slot.prompt.tokens.push_back(slot.sampled); if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); + SLT_INF(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); // fallback to normal decoding slot.i_batch = slot.i_batch_dft[0]; slot.drafted.clear(); slot.i_batch_dft.clear(); + + if (slot.spec_has_ckpt) { + slot.spec_n_accepted = 0; + slot.spec_n_denials = 0; + + // Delete Checkpoint + slot.prompt.checkpoints.pop_back(); + slot.spec_has_ckpt = false; + } } else { // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); @@ -2086,6 +2143,9 @@ private: // no speculative decoding slot.i_batch = batch.n_tokens; + slot.spec_n_denials = 0; + slot.spec_n_accepted = 0; + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); slot.prompt.tokens.push_back(slot.sampled); @@ -2538,6 +2598,7 @@ private: // no need for empty or small checkpoints do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + SLT_INF(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? 
"yes" : "no", pos_min, pos_max); // no need to create checkpoints that are too close together do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); @@ -2797,12 +2858,49 @@ private: const int64_t t_current = ggml_time_us(); - slot.n_decoded += ids.size(); + if (ids.size() < n_draft + 1 && slot.spec_has_ckpt) { + // the main model rejected some tokens, so we need to rollback to the state before sampling the draft tokens + auto & ckpt = slot.prompt.checkpoints.back(); + SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n", + ids.size() - 1, n_draft, + ckpt.pos_min, ckpt.pos_max); + const size_t n = llama_state_seq_set_data_ext(ctx, + ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != slot.spec_ckpt_size_part) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), slot.spec_ckpt_size_part, n); + } + SRV_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n", + ids.size() -1 , n_draft, n); + + // rollback to the state before sampling the draft tokens + SLT_INF(slot, "partial acceptance: n_tokens=%d, n_draft=%zu, pos_max=%d\n", + slot.prompt.n_tokens(), n_draft, ckpt.pos_max); + + slot.prompt.tokens.keep_first(ckpt.pos_max + 1); + + // Delete Checkpoint + slot.prompt.checkpoints.pop_back(); + slot.spec_has_ckpt = false; + + // Inform the speculative implementation of the number of valid tokens. + // common_speculative_accept(slot.spec, ids.size() - 1); + + slot.spec_n_denials++; + slot.spec_n_accepted = (slot.spec_n_denials < params_base.speculative.ckpt_num_tries) ? (int) (ids.size() - 1) : 0; + + common_batch_clear(batch); + + continue; + } + + slot.n_decoded += ids.size(); slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; + slot.spec_n_accepted = 0; // inform the speculative decoding about the number of accepted tokens common_speculative_accept(slot.spec, ids.size() - 1); @@ -2814,7 +2912,17 @@ private: slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); slot.sampled = ids.back(); // last accepted token - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + slot.spec_n_denials = 0; + if (slot.spec_has_ckpt) { + // Delete Checkpoint + if (slot.prompt.checkpoints.empty()) { + GGML_ABORT("missing checkpoint to delete"); + } + slot.prompt.checkpoints.pop_back(); + slot.spec_has_ckpt = false; + } else { + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result;