cont : remove server_prompt_checkpoint_with_size

This commit is contained in:
Georgi Gerganov 2026-04-03 16:35:23 +03:00
parent 8491e15405
commit e1141d1cd1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 12 additions and 21 deletions

View File

@ -13,6 +13,7 @@
#include <cstring>
#include <iomanip>
#include <map>
#include <cinttypes>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@ -240,9 +241,8 @@ struct common_speculative_state_draft : public common_speculative_state {
void begin(const llama_tokens & prompt) override {
if (use_checkpoint && ckpt.size() > 0) {
// delete checkpoint
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%zu, size=%.3f MiB\n",
__func__, prompt.size(),
ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
ckpt.pos_min = 0;
ckpt.pos_max = 0;
ckpt.n_tokens = 0;
@ -1374,7 +1374,7 @@ struct common_speculative_session::impl {
return common_speculative_accept_response{std::move(ids), n_draft, false};
}
void rewind(const llama_pos p0) {
void rewind(llama_pos p0) {
spec_ckpt_n_denials = 0;
if (spec_has_ckpt) {
// Delete Checkpoint

View File

@ -65,7 +65,7 @@ struct common_speculative_callback {
virtual ~common_speculative_callback();
// Add a token to the draft sequence.
virtual void batch_add_token(const llama_token token, bool logits) = 0;
virtual void batch_add_token(llama_token token, bool logits) = 0;
// Sample and accept tokens from the main model.
virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0;
@ -125,7 +125,7 @@ struct common_speculative_session {
common_speculative_accept_response sample_and_accept();
// rewind (because of a draft not fully accepted)
void rewind(const llama_pos p0);
void rewind(llama_pos p0);
// print statistics
void print_stats() const;

View File

@ -649,7 +649,7 @@ private:
return slot;
}
void batch_add_token(const llama_token token, bool logits) override {
void batch_add_token(llama_token token, bool logits) override {
server_slot * slot = get_slot();
slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens);
common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits);
@ -676,8 +676,7 @@ private:
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id);
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
auto & cur = cur_with_size.checkpoint;
const auto cur = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
@ -689,7 +688,7 @@ private:
// save sampler (we may want to restore the RNG in the sampler after refusal of a draft)
slot->spec_saved_sampler = common_sampler_clone(slot->smpl.get());
return cur_with_size.size;
return cur.size();
}
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
@ -721,10 +720,8 @@ private:
server_slot * slot = get_slot();
slot->prompt.checkpoints.pop_back();
}
};
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(const common_params & params) {
@ -1786,15 +1783,10 @@ private:
return true;
}
struct server_prompt_checkpoint_with_size {
server_prompt_checkpoint checkpoint;
size_t size;
};
// Creates a checkpoint.
//
// n_tokens_cur: the number of tokens added to the batch for the current slot
server_prompt_checkpoint_with_size get_checkpoint(server_slot & slot, const int64_t n_tokens_cur,
server_prompt_checkpoint get_checkpoint(server_slot & slot, const int64_t n_tokens_cur,
llama_pos pos_min, llama_pos pos_max) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
@ -1820,7 +1812,7 @@ private:
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
return server_prompt_checkpoint_with_size{ cur, checkpoint_size };
return cur;
}
void process_single_task(server_task && task) {
@ -2744,8 +2736,7 @@ private:
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
auto cur_with_size = get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
auto & cur = cur_with_size.checkpoint;
const auto cur = get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",