server : fix spec checkpoints, logging

This commit is contained in:
Sascha Rogmann 2026-03-01 23:40:35 +01:00
parent bd2f7f2d7f
commit af3b630e0b
3 changed files with 54 additions and 46 deletions

View File

@ -1136,13 +1136,15 @@ struct common_speculative_session::impl {
clear_draft();
return draft;
}
if (params_spec.use_checkpoints
&& spec_ckpt_n_denials > 0) {
if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) {
// We shouldn't get two denials.
LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__,
cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size());
clear_draft();
return draft;
}
if (spec_ckpt_n_denials > 0) {
if (spec_ckpt_n_denials == 1) {
// there is a previous speculation which wasn't accepted in full length
if (draft.empty()) {
LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__);
@ -1150,7 +1152,8 @@ struct common_speculative_session::impl {
return draft;
}
// we use the shortened draft of previous speculation
LOG_INF("%s: resuse shortened draft, size=%zu\n", __func__, draft.size());
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__,
cached_text_tokens.size(), id_last, draft.size());
} else {
// call the speculative implementation to create a draft
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
@ -1167,24 +1170,16 @@ struct common_speculative_session::impl {
}
bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints;
if (do_checkpoint && cached_text_tokens.size() > 5) {
LOG_DBG("draft.size = %zu, n_spec_denials = %d, do_checkpoint = %s, tokens=[..., %d, %d, %d]\n",
if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) {
LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n",
__func__,
cached_text_tokens.size(),
draft.size(), spec_ckpt_n_denials,
do_checkpoint ? "yes" : "no",
do_checkpoint ? "yes" : "no", id_last,
cached_text_tokens[cached_text_tokens.size() - 3],
cached_text_tokens[cached_text_tokens.size() - 2],
cached_text_tokens[cached_text_tokens.size() - 1]);
}
if (do_checkpoint) {
const size_t n = callback.create_checkpoint();
if (n == 0) {
LOG_WRN("checkpoint creation failed");
clear_draft();
return draft;
}
spec_ckpt_size_part = n;
spec_has_ckpt = true;
cached_text_tokens[cached_text_tokens.size() - 1],
draft[0], draft[1], draft[2]);
}
if (params_spec.n_min > (int) draft.size()) {
@ -1193,6 +1188,17 @@ struct common_speculative_session::impl {
return draft;
}
if (do_checkpoint) {
const size_t n = callback.create_checkpoint();
if (n == 0) {
LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size());
clear_draft();
return draft;
}
spec_ckpt_size_part = n;
spec_has_ckpt = true;
}
// add last sampled token to the batch
callback.batch_add_token(id_last, true);
@ -1219,27 +1225,31 @@ struct common_speculative_session::impl {
if (spec_has_ckpt) {
// we need to rollback to the state before sampling the draft tokens
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
LOG_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
ids.size() -1 , n_draft, n);
LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
__func__,
ids.size() - 1, n_draft, n);
// rollback to the state before sampling the draft tokens
// Delete Checkpoint
// delete Checkpoint
callback.delete_checkpoint();
spec_has_ckpt = false;
if (n_draft > 0 && spec_ckpt_n_denials == 0) {
spec_ckpt_n_denials++;
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
// we will do the batch again but with the shortened draft
spec_ckpt_n_denials++;
return common_speculative_accept_response(std::move(ids), n_draft, true);
}
callback.batch_clear();
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
draft.clear();
// use the sampled token only
ids.resize(1);
// drafted tokens in prompt have been deleted in restore_checkpoint(...).
return common_speculative_accept_response{std::move(ids), 0, false};
}
}
const size_t draft_size_accepted = draft.size();
LOG_DBG("%s: draft.size=%zu\n", __func__, draft_size_accepted);
LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size());
common_speculative_accept(spec, draft_size_accepted);
draft.clear();

View File

@ -61,9 +61,6 @@ struct common_speculative_callback {
// Add a token to the draft sequence.
virtual void batch_add_token(const llama_token token, bool logits) = 0;
// Clears the batch context.
virtual void batch_clear() = 0;
// Sample and accept tokens from the main model.
virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0;

View File

@ -1,3 +1,4 @@
#include "server-context.h"
#include "server-common.h"
#include "server-http.h"
@ -57,7 +58,7 @@ struct server_slot {
mtmd_context * mctx = nullptr;
std::unique_ptr<common_speculative_callback> spec_callback;
common_speculative_session * spec_session = nullptr;
std::unique_ptr<common_speculative_session> spec_session = nullptr;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@ -641,10 +642,6 @@ private:
slot.prompt.tokens.push_back(token);
}
void batch_clear() override {
common_batch_clear(ctx_impl.batch);
}
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
if (slot.i_batch_dft.size() != 1 + drafted.size()) {
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
@ -666,7 +663,7 @@ private:
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
auto & cur = cur_with_size.checkpoint;
SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
return cur_with_size.size;
}
@ -674,7 +671,7 @@ private:
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
auto & ckpt = slot.prompt.checkpoints.back();
SLT_INF(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) {
@ -855,7 +852,7 @@ private:
return false;
}
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this);
slot.spec_session = new common_speculative_session(*slot.spec_callback,
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
params_base.speculative, slot.ctx);
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
}
@ -2180,12 +2177,16 @@ private:
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
llama_tokens draft;
const int n_draft_max_slot = slot.get_n_draft_max();
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
if (draft.size() > 0) {
SLT_DBG(slot, "compute_draft: #tokens=%d\n", (int) draft.size());
if (n_draft_max_slot > 0) {
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
// compute draft and add draft to internal batch
draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
if (draft.size() > 0) {
SLT_DBG(slot, "compute_draft: #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n",
cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size());
}
}
if (draft.empty()) {
@ -2940,7 +2941,7 @@ private:
slot.i_batch_dft.clear();
const size_t n_draft = accept_response.draft_size_initial;
if (accept_response.skip_acceptance) {
SLT_INF(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft);
SLT_DBG(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft);
continue;
}
const auto ids = accept_response.tokens;