server : fix draft check with checkpoints
commit d03ebf3293
parent 8a4fe64310
@@ -191,6 +191,11 @@ struct server_slot {
         n_draft_total    = 0;
         n_draft_accepted = 0;

+        spec_n_denials      = 0;
+        spec_n_accepted     = 0;
+        spec_has_ckpt       = false;
+        spec_ckpt_size_part = 0;
+
         task_prev = std::move(task);
         task.reset();
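For orientation, a sketch of the per-slot speculative state these resets touch. The field names and reset values are taken from the diff; the enclosing struct, the types, and the comments are inferred and are not the actual server_slot definition:

    // Hypothetical sketch: types follow the assigned values (0 -> int,
    // false -> bool); the comments are guesses from the field names.
    struct spec_slot_state {
        int  spec_n_denials      = 0;     // speculative attempts denied so far
        int  spec_n_accepted     = 0;     // draft tokens accepted on the last attempt
        bool spec_has_ckpt       = false; // a checkpoint exists to roll back to
        int  spec_ckpt_size_part = 0;     // size of the partial checkpoint state
    };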
@@ -2049,7 +2054,8 @@ private:
             // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
             // perform the speculative drafting for all sequences at the same time in a single batch
             const int n_draft_max = (slot.spec_n_accepted > 0) ? slot.spec_n_accepted : slot.get_n_draft_max();
-            if (n_draft_max > 0 && slot.spec_n_denials < params_base.speculative.ckpt_num_tries) {
+            if (n_draft_max > 0 && (params_base.speculative.ckpt_num_tries == 0
+                                 || slot.spec_n_denials < params_base.speculative.ckpt_num_tries)) {
                 if (mctx) {
                     // we should never reach this, as speculative is automatically disabled if mmproj is loaded
                     GGML_ABORT("not supported by multimodal");
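This is the core of the fix: spec_n_denials is reset to 0 above and never satisfies a comparison against 0, so with ckpt_num_tries == 0 the old condition was always false and drafting was permanently denied. The new condition lets 0 bypass the denial limit, consistent with the next hunk, where checkpoints are only created when ckpt_num_tries > 0. A minimal standalone sketch of the gate (the helper name is hypothetical; the logic is copied from the diff):

    // Hypothetical helper mirroring the fixed condition.
    static bool spec_draft_allowed(int n_draft_max, int n_denials, int ckpt_num_tries) {
        // ckpt_num_tries == 0: the denial budget never gates drafting;
        // otherwise drafting stays enabled only while tries remain.
        return n_draft_max > 0 && (ckpt_num_tries == 0 || n_denials < ckpt_num_tries);
    }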
@@ -2068,16 +2074,19 @@ private:

                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                 const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
-                bool do_checkpoint = !draft.empty() && slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
-                SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
-                        draft.size(), slot.spec_n_denials,
-                        slot.prompt.checkpoints.size(),
-                        do_checkpoint ? "yes" : "no", pos_min, pos_max,
-                        cached_text_tokens[cached_text_tokens.size() - 3],
-                        cached_text_tokens[cached_text_tokens.size() - 2],
-                        cached_text_tokens[cached_text_tokens.size() - 1]);
+                bool do_checkpoint = !draft.empty() && params_base.speculative.ckpt_num_tries > 0
+                    && slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
+                if (do_checkpoint && cached_text_tokens.size() > 5) {
+                    SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
+                            draft.size(), slot.spec_n_denials,
+                            slot.prompt.checkpoints.size(),
+                            do_checkpoint ? "yes" : "no", pos_min, pos_max,
+                            cached_text_tokens[cached_text_tokens.size() - 3],
+                            cached_text_tokens[cached_text_tokens.size() - 2],
+                            cached_text_tokens[cached_text_tokens.size() - 1]);
+                }

                 if (do_checkpoint) {
                     while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
                         // make room for the new checkpoint, if needed
                         const auto & cur = slot.prompt.checkpoints.front();
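Two changes land here: do_checkpoint now also requires ckpt_num_tries > 0, so no checkpoints are created when the mechanism is off, and the debug log moves under a guard. The log reads cached_text_tokens at size() - 3 through size() - 1, which underflows the unsigned index on a short cache; the size() > 5 check keeps those reads in bounds. A standalone illustration of the guarded tail read (names are placeholders, not server code):

    #include <cstdio>
    #include <vector>

    // Skip the tail log until enough tokens are cached; the diff uses > 5,
    // while >= 3 would be the bare minimum for these three reads.
    static void log_tail(const std::vector<int> & v) {
        if (v.size() > 5) {
            std::printf("tokens=[..., %d, %d, %d]\n",
                        v[v.size() - 3], v[v.size() - 2], v[v.size() - 1]);
        }
    }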
@@ -2859,7 +2868,7 @@ private:
             const int64_t t_current = ggml_time_us();

-            if (ids.size() < n_draft + 1 && slot.spec_has_ckpt) {
+            if (slot.spec_has_ckpt && ids.size() < n_draft + 1) {
                 // the main model rejected some tokens, so we need to rollback to the state before sampling the draft tokens
                 auto & ckpt = slot.prompt.checkpoints.back();
                 SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n",
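The reorder does not change the result, since both operands are side-effect-free, but it puts the cheap flag test first and states the precondition for the checkpoints.back() call that follows: a checkpoint must exist before a rollback is considered. A standalone sketch with hypothetical names:

    #include <cstddef>
    #include <vector>

    // With &&, the size comparison is evaluated only when a checkpoint
    // exists, matching the order of the rollback logic it guards.
    static bool should_restore(bool has_ckpt, const std::vector<int> & ids, size_t n_draft) {
        return has_ckpt && ids.size() < n_draft + 1;
    }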