server : fix server_speculative_callback (slot.id)
This commit is contained in:
parent
91932ae05b
commit
b5b3ac3b55
|
|
@ -630,61 +630,75 @@ private:
|
|||
// callback for speculative decoding
|
||||
//
|
||||
struct server_speculative_callback : public common_speculative_callback {
|
||||
server_slot & slot;
|
||||
int slot_id; // store slot.id instead of server_slot & slot
|
||||
server_context_impl & ctx_impl;
|
||||
|
||||
server_speculative_callback(server_slot & slot, server_context_impl & ctx_impl)
|
||||
: slot(slot), ctx_impl(ctx_impl) {}
|
||||
server_speculative_callback(int slot_id, server_context_impl & ctx_impl)
|
||||
: slot_id(slot_id), ctx_impl(ctx_impl) {}
|
||||
|
||||
server_slot * get_slot() {
|
||||
server_slot * slot = ctx_impl.get_slot_by_id(slot_id);
|
||||
if (slot == nullptr) {
|
||||
GGML_ABORT("missing slot, slot.id=%d", slot_id);
|
||||
}
|
||||
return slot;
|
||||
}
|
||||
|
||||
void batch_add_token(const llama_token token, bool logits) override {
|
||||
slot.i_batch_dft.push_back(ctx_impl.batch.n_tokens);
|
||||
common_batch_add(ctx_impl.batch, token, slot.prompt.tokens.pos_next(), { slot.id }, logits);
|
||||
slot.prompt.tokens.push_back(token);
|
||||
server_slot * slot = get_slot();
|
||||
slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens);
|
||||
common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits);
|
||||
slot->prompt.tokens.push_back(token);
|
||||
}
|
||||
|
||||
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
|
||||
if (slot.i_batch_dft.size() != 1 + drafted.size()) {
|
||||
const server_slot * slot = get_slot();
|
||||
if (slot->i_batch_dft.size() != 1 + drafted.size()) {
|
||||
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
|
||||
__func__, slot.i_batch_dft.size(), 1 + drafted.size());
|
||||
__func__, slot->i_batch_dft.size(), 1 + drafted.size());
|
||||
}
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx_impl.ctx, slot.i_batch_dft, drafted);
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted);
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
|
||||
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot.id, p0, p1);
|
||||
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1);
|
||||
}
|
||||
|
||||
size_t create_checkpoint() override {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot.id);
|
||||
server_slot * slot = get_slot();
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id);
|
||||
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
|
||||
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
|
||||
const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
|
||||
auto & cur = cur_with_size.checkpoint;
|
||||
|
||||
SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
|
||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
return cur_with_size.size;
|
||||
}
|
||||
|
||||
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
|
||||
auto & ckpt = slot.prompt.checkpoints.back();
|
||||
server_slot * slot = get_slot();
|
||||
auto & ckpt = slot->prompt.checkpoints.back();
|
||||
|
||||
SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
|
||||
SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
|
||||
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != ckpt_size_part_expected) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||
}
|
||||
|
||||
slot.prompt.tokens.keep_first(ckpt.pos_max + 1);
|
||||
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
|
||||
return n;
|
||||
}
|
||||
|
||||
void delete_checkpoint() override {
|
||||
slot.prompt.checkpoints.pop_back();
|
||||
server_slot * slot = get_slot();
|
||||
slot->prompt.checkpoints.pop_back();
|
||||
}
|
||||
|
||||
};
|
||||
|
|
@ -851,7 +865,7 @@ private:
|
|||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
||||
return false;
|
||||
}
|
||||
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this);
|
||||
slot.spec_callback = std::make_unique<server_speculative_callback>(slot.id, *this);
|
||||
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
|
||||
params_base.speculative, slot.ctx);
|
||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||
|
|
|
|||
Loading…
Reference in New Issue