diff --git a/common/speculative.cpp b/common/speculative.cpp index 4c6caf66fa..a863690a64 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1353,8 +1353,7 @@ struct common_speculative_session::impl { spec_ckpt_n_denials++; if (ids.size() > 1u + static_cast(params_spec.n_min) && spec_ckpt_n_denials == 1) { // we will do the batch again but with the shortened draft - //return common_speculative_accept_response(std::move(ids), n_draft, true); - LOG_DBG("%s: partial draft disabled\n", __func__); + return common_speculative_accept_response(std::move(ids), n_draft, true); } LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size()); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 5230f33034..2c624de7d1 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -58,7 +58,8 @@ struct server_slot { mtmd_context * mctx = nullptr; std::unique_ptr spec_callback; - std::unique_ptr spec_session = nullptr; + std::unique_ptr spec_session = nullptr; + struct common_sampler * spec_saved_sampler = nullptr; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -607,6 +608,10 @@ private: slot.spec_session->reset(); slot.spec_session = nullptr; } + if (slot.spec_saved_sampler != nullptr) { + common_sampler_free(slot.spec_saved_sampler); + slot.spec_saved_sampler = nullptr; + } } llama_batch_free(batch); @@ -677,6 +682,13 @@ private: SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + if (slot->spec_saved_sampler != nullptr) { + common_sampler_free(slot->spec_saved_sampler); + } + // save sampler (we may want to restore the RNG in the sampler after refusal of a draft) + slot->spec_saved_sampler = common_sampler_clone(slot->smpl.get()); + return cur_with_size.size; } @@ -691,8 +703,17 @@ private: GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); } + // remove entries after ckpt.pos_max + llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot->id, ckpt.pos_max + 1, -1); slot->prompt.tokens.keep_first(ckpt.pos_max + 1); + + if (slot->spec_saved_sampler != nullptr) { + slot->smpl.reset(slot->spec_saved_sampler); + + slot->spec_saved_sampler = nullptr; + } + return n; }