server : fix server_speculative_callback (slot.id)
This commit is contained in:
parent
91932ae05b
commit
b5b3ac3b55
|
|
@ -630,61 +630,75 @@ private:
|
||||||
// callback for speculative decoding
|
// callback for speculative decoding
|
||||||
//
|
//
|
||||||
struct server_speculative_callback : public common_speculative_callback {
|
struct server_speculative_callback : public common_speculative_callback {
|
||||||
server_slot & slot;
|
int slot_id; // store slot.id instead of server_slot & slot
|
||||||
server_context_impl & ctx_impl;
|
server_context_impl & ctx_impl;
|
||||||
|
|
||||||
server_speculative_callback(server_slot & slot, server_context_impl & ctx_impl)
|
server_speculative_callback(int slot_id, server_context_impl & ctx_impl)
|
||||||
: slot(slot), ctx_impl(ctx_impl) {}
|
: slot_id(slot_id), ctx_impl(ctx_impl) {}
|
||||||
|
|
||||||
|
server_slot * get_slot() {
|
||||||
|
server_slot * slot = ctx_impl.get_slot_by_id(slot_id);
|
||||||
|
if (slot == nullptr) {
|
||||||
|
GGML_ABORT("missing slot, slot.id=%d", slot_id);
|
||||||
|
}
|
||||||
|
return slot;
|
||||||
|
}
|
||||||
|
|
||||||
void batch_add_token(const llama_token token, bool logits) override {
|
void batch_add_token(const llama_token token, bool logits) override {
|
||||||
slot.i_batch_dft.push_back(ctx_impl.batch.n_tokens);
|
server_slot * slot = get_slot();
|
||||||
common_batch_add(ctx_impl.batch, token, slot.prompt.tokens.pos_next(), { slot.id }, logits);
|
slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens);
|
||||||
slot.prompt.tokens.push_back(token);
|
common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits);
|
||||||
|
slot->prompt.tokens.push_back(token);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
|
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
|
||||||
if (slot.i_batch_dft.size() != 1 + drafted.size()) {
|
const server_slot * slot = get_slot();
|
||||||
|
if (slot->i_batch_dft.size() != 1 + drafted.size()) {
|
||||||
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
|
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
|
||||||
__func__, slot.i_batch_dft.size(), 1 + drafted.size());
|
__func__, slot->i_batch_dft.size(), 1 + drafted.size());
|
||||||
}
|
}
|
||||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx_impl.ctx, slot.i_batch_dft, drafted);
|
const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted);
|
||||||
|
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
|
bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
|
||||||
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot.id, p0, p1);
|
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t create_checkpoint() override {
|
size_t create_checkpoint() override {
|
||||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot.id);
|
server_slot * slot = get_slot();
|
||||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot.id);
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id);
|
||||||
|
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id);
|
||||||
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
|
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
|
||||||
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
|
const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
|
||||||
auto & cur = cur_with_size.checkpoint;
|
auto & cur = cur_with_size.checkpoint;
|
||||||
|
|
||||||
SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||||
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
|
||||||
|
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||||
return cur_with_size.size;
|
return cur_with_size.size;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
|
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
|
||||||
auto & ckpt = slot.prompt.checkpoints.back();
|
server_slot * slot = get_slot();
|
||||||
|
auto & ckpt = slot->prompt.checkpoints.back();
|
||||||
|
|
||||||
SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
|
SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
|
||||||
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
|
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
|
||||||
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
if (n != ckpt_size_part_expected) {
|
if (n != ckpt_size_part_expected) {
|
||||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.prompt.tokens.keep_first(ckpt.pos_max + 1);
|
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
void delete_checkpoint() override {
|
void delete_checkpoint() override {
|
||||||
slot.prompt.checkpoints.pop_back();
|
server_slot * slot = get_slot();
|
||||||
|
slot->prompt.checkpoints.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
@ -851,7 +865,7 @@ private:
|
||||||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this);
|
slot.spec_callback = std::make_unique<server_speculative_callback>(slot.id, *this);
|
||||||
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
|
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
|
||||||
params_base.speculative, slot.ctx);
|
params_base.speculative, slot.ctx);
|
||||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue