Compare commits

...

5 Commits

Author SHA1 Message Date
Sascha Rogmann 6bdff4e303
Merge c591189213 into 338085c69e 2026-02-12 14:49:07 -06:00
Sascha Rogmann c591189213 server : log levels 2026-02-10 23:51:27 +01:00
Sascha Rogmann cc6e40460a server : rename spec vars 2026-02-10 22:57:47 +01:00
Sascha Rogmann d03ebf3293 server : fix draft check with checkpoints 2026-02-10 22:32:59 +01:00
Sascha Rogmann 8a4fe64310 server : speculative decoding using checkpoints 2026-02-09 23:21:58 +01:00
4 changed files with 136 additions and 8 deletions

View File

@ -3447,6 +3447,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ckpt-num-tries"}, "N",
string_format("number of tries for speculative decoding with recurrent memory (default: %d)", params.speculative.ckpt_num_tries),
[](common_params & params, int value) {
if (value < 0 || value > 10) {
throw std::invalid_argument("number of tries must be between 0 and 10 inclusive");
}
params.speculative.ckpt_num_tries = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(

View File

@ -270,6 +270,7 @@ struct common_params_speculative {
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
uint16_t ckpt_num_tries = 0; // number of tries in case of recurrent memory
std::shared_ptr<common_ngram_mod> ngram_mod;

View File

@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
}
if (map.idx_last_check > cur_len) {
if (map.idx_last_check > cur_len) {
// Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
}
@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
map.last_draft_created = false;
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;

View File

@ -146,6 +146,13 @@ struct server_slot {
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
// use of checkpoints in speculative mode
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position
int spec_ckpt_n_accepted = 0; // number of accepted tokens at current position
size_t spec_ckpt_size_part = 0; // size of partial checkpoint
// stats
size_t n_sent_text = 0; // number of sent text character
@ -184,6 +191,11 @@ struct server_slot {
n_draft_total = 0;
n_draft_accepted = 0;
spec_ckpt_n_denials = 0;
spec_ckpt_n_accepted = 0;
spec_has_ckpt = false;
spec_ckpt_size_part = 0;
task_prev = std::move(task);
task.reset();
@ -742,7 +754,7 @@ private:
const bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
SRV_WRN("%s", "speculative decoding not supported by this context without checkpoints\n");
}
// initialize slots
@ -757,7 +769,7 @@ private:
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
if (can_spec) {
if (can_spec || params_base.speculative.ckpt_num_tries > 0) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
@ -2041,8 +2053,9 @@ private:
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
const int n_draft_max = (slot.spec_ckpt_n_accepted > 0) ? slot.spec_ckpt_n_accepted : slot.get_n_draft_max();
if (n_draft_max > 0 && (params_base.speculative.ckpt_num_tries == 0
|| slot.spec_ckpt_n_denials < params_base.speculative.ckpt_num_tries)) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
@ -2059,8 +2072,52 @@ private:
draft.resize(n_draft_max);
}
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
bool do_checkpoint = !draft.empty() && params_base.speculative.ckpt_num_tries > 0
&& slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
if (do_checkpoint && cached_text_tokens.size() > 5) {
SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
draft.size(), slot.spec_ckpt_n_denials,
slot.prompt.checkpoints.size(),
do_checkpoint ? "yes" : "no", pos_min, pos_max,
cached_text_tokens[cached_text_tokens.size() - 3],
cached_text_tokens[cached_text_tokens.size() - 2],
cached_text_tokens[cached_text_tokens.size() - 1]);
}
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, 0);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
slot.spec_ckpt_size_part = n;
slot.spec_has_ckpt = true;
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
SLT_DBG(slot, "before common_batch_add: sampled=%d, pos_next=%d, tokens.size=%zu, tokens.last=%d\n",
slot.sampled, slot.prompt.tokens.pos_next(), slot.prompt.tokens.size(), slot.prompt.tokens[slot.prompt.tokens.size() -1]);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
@ -2070,6 +2127,15 @@ private:
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
if (slot.spec_has_ckpt) {
slot.spec_ckpt_n_accepted = 0;
slot.spec_ckpt_n_denials = 0;
// Delete Checkpoint
slot.prompt.checkpoints.pop_back();
slot.spec_has_ckpt = false;
}
} else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
@ -2086,6 +2152,9 @@ private:
// no speculative decoding
slot.i_batch = batch.n_tokens;
slot.spec_ckpt_n_denials = 0;
slot.spec_ckpt_n_accepted = 0;
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
@ -2538,6 +2607,7 @@ private:
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
@ -2797,12 +2867,49 @@ private:
const int64_t t_current = ggml_time_us();
slot.n_decoded += ids.size();
if (slot.spec_has_ckpt && ids.size() < n_draft + 1) {
// the main model rejected some tokens, so we need to rollback to the state before sampling the draft tokens
auto & ckpt = slot.prompt.checkpoints.back();
SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n",
ids.size() - 1, n_draft,
ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx,
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != slot.spec_ckpt_size_part) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), slot.spec_ckpt_size_part, n);
}
SRV_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
ids.size() -1 , n_draft, n);
// rollback to the state before sampling the draft tokens
SLT_INF(slot, "partial acceptance: n_tokens=%d, n_draft=%zu, pos_max=%d\n",
slot.prompt.n_tokens(), n_draft, ckpt.pos_max);
slot.prompt.tokens.keep_first(ckpt.pos_max + 1);
// Delete Checkpoint
slot.prompt.checkpoints.pop_back();
slot.spec_has_ckpt = false;
// Inform the speculative implementation of the number of valid tokens.
// common_speculative_accept(slot.spec, ids.size() - 1);
slot.spec_ckpt_n_denials++;
slot.spec_ckpt_n_accepted = (slot.spec_ckpt_n_denials < params_base.speculative.ckpt_num_tries) ? (int) (ids.size() - 1) : 0;
common_batch_clear(batch);
continue;
}
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.spec_ckpt_n_accepted = 0;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
@ -2814,7 +2921,17 @@ private:
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
slot.spec_ckpt_n_denials = 0;
if (slot.spec_has_ckpt) {
// Delete Checkpoint
if (slot.prompt.checkpoints.empty()) {
GGML_ABORT("missing checkpoint to delete");
}
slot.prompt.checkpoints.pop_back();
slot.spec_has_ckpt = false;
} else {
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;