server : fix server_speculative_callback (slot.id)

This commit is contained in:
Sascha Rogmann 2026-03-11 22:18:25 +01:00
parent 91932ae05b
commit b5b3ac3b55
1 changed files with 35 additions and 21 deletions

View File

@ -630,61 +630,75 @@ private:
// callback for speculative decoding
//
struct server_speculative_callback : public common_speculative_callback {
server_slot & slot;
int slot_id; // store slot.id instead of server_slot & slot
server_context_impl & ctx_impl;
server_speculative_callback(server_slot & slot, server_context_impl & ctx_impl)
: slot(slot), ctx_impl(ctx_impl) {}
server_speculative_callback(int slot_id, server_context_impl & ctx_impl)
: slot_id(slot_id), ctx_impl(ctx_impl) {}
server_slot * get_slot() {
server_slot * slot = ctx_impl.get_slot_by_id(slot_id);
if (slot == nullptr) {
GGML_ABORT("missing slot, slot.id=%d", slot_id);
}
return slot;
}
void batch_add_token(const llama_token token, bool logits) override {
slot.i_batch_dft.push_back(ctx_impl.batch.n_tokens);
common_batch_add(ctx_impl.batch, token, slot.prompt.tokens.pos_next(), { slot.id }, logits);
slot.prompt.tokens.push_back(token);
server_slot * slot = get_slot();
slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens);
common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits);
slot->prompt.tokens.push_back(token);
}
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
if (slot.i_batch_dft.size() != 1 + drafted.size()) {
const server_slot * slot = get_slot();
if (slot->i_batch_dft.size() != 1 + drafted.size()) {
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
__func__, slot.i_batch_dft.size(), 1 + drafted.size());
__func__, slot->i_batch_dft.size(), 1 + drafted.size());
}
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx_impl.ctx, slot.i_batch_dft, drafted);
const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted);
return ids;
}
bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot.id, p0, p1);
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1);
}
size_t create_checkpoint() override {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot.id);
server_slot * slot = get_slot();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id);
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
auto & cur = cur_with_size.checkpoint;
SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
return cur_with_size.size;
}
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
auto & ckpt = slot.prompt.checkpoints.back();
server_slot * slot = get_slot();
auto & ckpt = slot->prompt.checkpoints.back();
SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
}
slot.prompt.tokens.keep_first(ckpt.pos_max + 1);
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
return n;
}
void delete_checkpoint() override {
slot.prompt.checkpoints.pop_back();
server_slot * slot = get_slot();
slot->prompt.checkpoints.pop_back();
}
};
@ -851,7 +865,7 @@ private:
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false;
}
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this);
slot.spec_callback = std::make_unique<server_speculative_callback>(slot.id, *this);
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
params_base.speculative, slot.ctx);
SLT_INF(slot, "%s", "speculative decoding context initialized\n");