From b5b3ac3b5527453f4e7e71be2f65bcce269c1183 Mon Sep 17 00:00:00 2001 From: Sascha Rogmann Date: Wed, 11 Mar 2026 22:18:25 +0100 Subject: [PATCH] server : fix server_speculative_callback (slot.id) --- tools/server/server-context.cpp | 56 ++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index db9ac5c9b6..5230f33034 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -630,61 +630,75 @@ private: // callback for speculative decoding // struct server_speculative_callback : public common_speculative_callback { - server_slot & slot; + int slot_id; // store slot.id instead of server_slot & slot server_context_impl & ctx_impl; - server_speculative_callback(server_slot & slot, server_context_impl & ctx_impl) - : slot(slot), ctx_impl(ctx_impl) {} + server_speculative_callback(int slot_id, server_context_impl & ctx_impl) + : slot_id(slot_id), ctx_impl(ctx_impl) {} + + server_slot * get_slot() { + server_slot * slot = ctx_impl.get_slot_by_id(slot_id); + if (slot == nullptr) { + GGML_ABORT("missing slot, slot.id=%d", slot_id); + } + return slot; + } void batch_add_token(const llama_token token, bool logits) override { - slot.i_batch_dft.push_back(ctx_impl.batch.n_tokens); - common_batch_add(ctx_impl.batch, token, slot.prompt.tokens.pos_next(), { slot.id }, logits); - slot.prompt.tokens.push_back(token); + server_slot * slot = get_slot(); + slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens); + common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits); + slot->prompt.tokens.push_back(token); } std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override { - if (slot.i_batch_dft.size() != 1 + drafted.size()) { + const server_slot * slot = get_slot(); + if (slot->i_batch_dft.size() != 1 + drafted.size()) { GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu", - __func__, 
slot.i_batch_dft.size(), 1 + drafted.size()); + __func__, slot->i_batch_dft.size(), 1 + drafted.size()); } - const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx_impl.ctx, slot.i_batch_dft, drafted); + const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted); return ids; } bool memory_seq_rm(llama_pos p0, llama_pos p1) override { - return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot.id, p0, p1); + return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1); } size_t create_checkpoint() override { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot.id); + server_slot * slot = get_slot(); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id); const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt? 
- const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max); + const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max); auto & cur = cur_with_size.checkpoint; - SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); return cur_with_size.size; } size_t restore_checkpoint(size_t ckpt_size_part_expected) override { - auto & ckpt = slot.prompt.checkpoints.back(); + server_slot * slot = get_slot(); + auto & ckpt = slot->prompt.checkpoints.back(); - SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max); + SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max); const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx, - ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != ckpt_size_part_expected) { GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); } - slot.prompt.tokens.keep_first(ckpt.pos_max + 1); + slot->prompt.tokens.keep_first(ckpt.pos_max + 1); return n; } void delete_checkpoint() override { - slot.prompt.checkpoints.pop_back(); + server_slot * slot = get_slot(); + slot->prompt.checkpoints.pop_back(); } }; @@ -851,7 +865,7 @@ private: SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); return 
false; } - slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this); + slot.spec_callback = std::make_unique<server_speculative_callback>(slot.id, *this); slot.spec_session = std::make_unique<common_speculative>(*slot.spec_callback, params_base.speculative, slot.ctx); SLT_INF(slot, "%s", "speculative decoding context initialized\n");