server : restore sampler in spec checkpoint and clear mem

This commit is contained in:
Sascha Rogmann 2026-03-26 23:37:05 +01:00
parent 7d2814a9bd
commit d0a856895f
2 changed files with 23 additions and 3 deletions

View File

@ -1353,8 +1353,7 @@ struct common_speculative_session::impl {
spec_ckpt_n_denials++;
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
// we will do the batch again but with the shortened draft
//return common_speculative_accept_response(std::move(ids), n_draft, true);
LOG_DBG("%s: partial draft disabled\n", __func__);
return common_speculative_accept_response(std::move(ids), n_draft, true);
}
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());

View File

@ -58,7 +58,8 @@ struct server_slot {
mtmd_context * mctx = nullptr;
std::unique_ptr<common_speculative_callback> spec_callback;
std::unique_ptr<common_speculative_session> spec_session = nullptr;
std::unique_ptr<common_speculative_session> spec_session = nullptr;
struct common_sampler * spec_saved_sampler = nullptr;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@ -607,6 +608,10 @@ private:
slot.spec_session->reset();
slot.spec_session = nullptr;
}
if (slot.spec_saved_sampler != nullptr) {
common_sampler_free(slot.spec_saved_sampler);
slot.spec_saved_sampler = nullptr;
}
}
llama_batch_free(batch);
@ -677,6 +682,13 @@ private:
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
if (slot->spec_saved_sampler != nullptr) {
common_sampler_free(slot->spec_saved_sampler);
}
// save sampler (we may want to restore the RNG in the sampler after refusal of a draft)
slot->spec_saved_sampler = common_sampler_clone(slot->smpl.get());
return cur_with_size.size;
}
@ -691,8 +703,17 @@ private:
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
}
// remove entries after ckpt.pos_max
llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot->id, ckpt.pos_max + 1, -1);
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
if (slot->spec_saved_sampler != nullptr) {
slot->smpl.reset(slot->spec_saved_sampler);
slot->spec_saved_sampler = nullptr;
}
return n;
}