server : restore sampler in spec checkpoint and clear mem
This commit is contained in:
parent
7d2814a9bd
commit
d0a856895f
|
|
@ -1353,8 +1353,7 @@ struct common_speculative_session::impl {
|
||||||
spec_ckpt_n_denials++;
|
spec_ckpt_n_denials++;
|
||||||
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
|
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
|
||||||
// we will do the batch again but with the shortened draft
|
// we will do the batch again but with the shortened draft
|
||||||
//return common_speculative_accept_response(std::move(ids), n_draft, true);
|
return common_speculative_accept_response(std::move(ids), n_draft, true);
|
||||||
LOG_DBG("%s: partial draft disabled\n", __func__);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
|
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,8 @@ struct server_slot {
|
||||||
mtmd_context * mctx = nullptr;
|
mtmd_context * mctx = nullptr;
|
||||||
|
|
||||||
std::unique_ptr<common_speculative_callback> spec_callback;
|
std::unique_ptr<common_speculative_callback> spec_callback;
|
||||||
std::unique_ptr<common_speculative_session> spec_session = nullptr;
|
std::unique_ptr<common_speculative_session> spec_session = nullptr;
|
||||||
|
struct common_sampler * spec_saved_sampler = nullptr;
|
||||||
|
|
||||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||||
|
|
@ -607,6 +608,10 @@ private:
|
||||||
slot.spec_session->reset();
|
slot.spec_session->reset();
|
||||||
slot.spec_session = nullptr;
|
slot.spec_session = nullptr;
|
||||||
}
|
}
|
||||||
|
if (slot.spec_saved_sampler != nullptr) {
|
||||||
|
common_sampler_free(slot.spec_saved_sampler);
|
||||||
|
slot.spec_saved_sampler = nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
|
|
@ -677,6 +682,13 @@ private:
|
||||||
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||||
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
|
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
|
||||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||||
|
|
||||||
|
if (slot->spec_saved_sampler != nullptr) {
|
||||||
|
common_sampler_free(slot->spec_saved_sampler);
|
||||||
|
}
|
||||||
|
// save sampler (we may want to restore the RNG in the sampler after refusal of a draft)
|
||||||
|
slot->spec_saved_sampler = common_sampler_clone(slot->smpl.get());
|
||||||
|
|
||||||
return cur_with_size.size;
|
return cur_with_size.size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -691,8 +703,17 @@ private:
|
||||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||||
}
|
}
|
||||||
|
// remove entries after ckpt.pos_max
|
||||||
|
llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot->id, ckpt.pos_max + 1, -1);
|
||||||
|
|
||||||
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
|
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
|
||||||
|
|
||||||
|
if (slot->spec_saved_sampler != nullptr) {
|
||||||
|
slot->smpl.reset(slot->spec_saved_sampler);
|
||||||
|
|
||||||
|
slot->spec_saved_sampler = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue