server : fix draft check with checkpoints
This commit is contained in:
parent
07747a0836
commit
c9fc6af71d
|
|
@ -191,6 +191,11 @@ struct server_slot {
|
||||||
n_draft_total = 0;
|
n_draft_total = 0;
|
||||||
n_draft_accepted = 0;
|
n_draft_accepted = 0;
|
||||||
|
|
||||||
|
spec_n_denials = 0;
|
||||||
|
spec_n_accepted = 0;
|
||||||
|
spec_has_ckpt = false;
|
||||||
|
spec_ckpt_size_part = 0;
|
||||||
|
|
||||||
task_prev = std::move(task);
|
task_prev = std::move(task);
|
||||||
task.reset();
|
task.reset();
|
||||||
|
|
||||||
|
|
@ -2049,7 +2054,8 @@ private:
|
||||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||||
const int n_draft_max = (slot.spec_n_accepted > 0) ? slot.spec_n_accepted : slot.get_n_draft_max();
|
const int n_draft_max = (slot.spec_n_accepted > 0) ? slot.spec_n_accepted : slot.get_n_draft_max();
|
||||||
if (n_draft_max > 0 && slot.spec_n_denials < params_base.speculative.ckpt_num_tries) {
|
if (n_draft_max > 0 && (params_base.speculative.ckpt_num_tries == 0
|
||||||
|
|| slot.spec_n_denials < params_base.speculative.ckpt_num_tries)) {
|
||||||
if (mctx) {
|
if (mctx) {
|
||||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||||
GGML_ABORT("not supported by multimodal");
|
GGML_ABORT("not supported by multimodal");
|
||||||
|
|
@ -2068,16 +2074,19 @@ private:
|
||||||
|
|
||||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||||
bool do_checkpoint = !draft.empty() && slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
|
bool do_checkpoint = !draft.empty() && params_base.speculative.ckpt_num_tries > 0
|
||||||
SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
|
&& slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
|
||||||
draft.size(), slot.spec_n_denials,
|
if (do_checkpoint && cached_text_tokens.size() > 5) {
|
||||||
slot.prompt.checkpoints.size(),
|
SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
|
||||||
do_checkpoint ? "yes" : "no", pos_min, pos_max,
|
draft.size(), slot.spec_n_denials,
|
||||||
cached_text_tokens[cached_text_tokens.size() - 3],
|
slot.prompt.checkpoints.size(),
|
||||||
cached_text_tokens[cached_text_tokens.size() - 2],
|
do_checkpoint ? "yes" : "no", pos_min, pos_max,
|
||||||
cached_text_tokens[cached_text_tokens.size() - 1]);
|
cached_text_tokens[cached_text_tokens.size() - 3],
|
||||||
|
cached_text_tokens[cached_text_tokens.size() - 2],
|
||||||
|
cached_text_tokens[cached_text_tokens.size() - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
if (do_checkpoint) {
|
if (do_checkpoint) {
|
||||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||||
// make room for the new checkpoint, if needed
|
// make room for the new checkpoint, if needed
|
||||||
const auto & cur = slot.prompt.checkpoints.front();
|
const auto & cur = slot.prompt.checkpoints.front();
|
||||||
|
|
@ -2859,7 +2868,7 @@ private:
|
||||||
const int64_t t_current = ggml_time_us();
|
const int64_t t_current = ggml_time_us();
|
||||||
|
|
||||||
|
|
||||||
if (ids.size() < n_draft + 1 && slot.spec_has_ckpt) {
|
if (slot.spec_has_ckpt && ids.size() < n_draft + 1) {
|
||||||
// the main model rejected some tokens, so we need to rollback to the state before sampling the draft tokens
|
// the main model rejected some tokens, so we need to rollback to the state before sampling the draft tokens
|
||||||
auto & ckpt = slot.prompt.checkpoints.back();
|
auto & ckpt = slot.prompt.checkpoints.back();
|
||||||
SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n",
|
SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue