server : fix server_speculative_callback (slot.id)

This commit is contained in:
Sascha Rogmann 2026-03-11 22:18:25 +01:00
parent 91932ae05b
commit b5b3ac3b55
1 changed files with 35 additions and 21 deletions

View File

@ -630,61 +630,75 @@ private:
// callback for speculative decoding // callback for speculative decoding
// //
struct server_speculative_callback : public common_speculative_callback { struct server_speculative_callback : public common_speculative_callback {
server_slot & slot; int slot_id; // store slot.id instead of server_slot & slot
server_context_impl & ctx_impl; server_context_impl & ctx_impl;
server_speculative_callback(server_slot & slot, server_context_impl & ctx_impl) server_speculative_callback(int slot_id, server_context_impl & ctx_impl)
: slot(slot), ctx_impl(ctx_impl) {} : slot_id(slot_id), ctx_impl(ctx_impl) {}
server_slot * get_slot() {
server_slot * slot = ctx_impl.get_slot_by_id(slot_id);
if (slot == nullptr) {
GGML_ABORT("missing slot, slot.id=%d", slot_id);
}
return slot;
}
void batch_add_token(const llama_token token, bool logits) override { void batch_add_token(const llama_token token, bool logits) override {
slot.i_batch_dft.push_back(ctx_impl.batch.n_tokens); server_slot * slot = get_slot();
common_batch_add(ctx_impl.batch, token, slot.prompt.tokens.pos_next(), { slot.id }, logits); slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens);
slot.prompt.tokens.push_back(token); common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits);
slot->prompt.tokens.push_back(token);
} }
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override { std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
if (slot.i_batch_dft.size() != 1 + drafted.size()) { const server_slot * slot = get_slot();
if (slot->i_batch_dft.size() != 1 + drafted.size()) {
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu", GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
__func__, slot.i_batch_dft.size(), 1 + drafted.size()); __func__, slot->i_batch_dft.size(), 1 + drafted.size());
} }
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx_impl.ctx, slot.i_batch_dft, drafted); const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted);
return ids; return ids;
} }
bool memory_seq_rm(llama_pos p0, llama_pos p1) override { bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot.id, p0, p1); return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1);
} }
size_t create_checkpoint() override { size_t create_checkpoint() override {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot.id); server_slot * slot = get_slot();
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot.id); const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id);
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt? const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max); const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
auto & cur = cur_with_size.checkpoint; auto & cur = cur_with_size.checkpoint;
SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
return cur_with_size.size; return cur_with_size.size;
} }
size_t restore_checkpoint(size_t ckpt_size_part_expected) override { size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
auto & ckpt = slot.prompt.checkpoints.back(); server_slot * slot = get_slot();
auto & ckpt = slot->prompt.checkpoints.back();
SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max); SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx, const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) { if (n != ckpt_size_part_expected) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
} }
slot.prompt.tokens.keep_first(ckpt.pos_max + 1); slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
return n; return n;
} }
void delete_checkpoint() override { void delete_checkpoint() override {
slot.prompt.checkpoints.pop_back(); server_slot * slot = get_slot();
slot->prompt.checkpoints.pop_back();
} }
}; };
@ -851,7 +865,7 @@ private:
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false; return false;
} }
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this); slot.spec_callback = std::make_unique<server_speculative_callback>(slot.id, *this);
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback, slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
params_base.speculative, slot.ctx); params_base.speculative, slot.ctx);
SLT_INF(slot, "%s", "speculative decoding context initialized\n"); SLT_INF(slot, "%s", "speculative decoding context initialized\n");