diff --git a/common/speculative.cpp b/common/speculative.cpp index 581c18419d..537aa6cf1f 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1136,13 +1136,15 @@ struct common_speculative_session::impl { clear_draft(); return draft; } - if (params_spec.use_checkpoints - && spec_ckpt_n_denials > 0) { + if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) { + // We shouldn't get two denials. + LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__, + cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size()); clear_draft(); return draft; } - if (spec_ckpt_n_denials > 0) { + if (spec_ckpt_n_denials == 1) { // there is a previous speculation which wasn't accepted in full length if (draft.empty()) { LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__); @@ -1150,7 +1152,8 @@ struct common_speculative_session::impl { return draft; } // we use the shortened draft of previous speculation - LOG_INF("%s: resuse shortened draft, size=%zu\n", __func__, draft.size()); + LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__, + cached_text_tokens.size(), id_last, draft.size()); } else { // call the speculative implementation to create a draft draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last); @@ -1167,24 +1170,16 @@ struct common_speculative_session::impl { } bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints; - if (do_checkpoint && cached_text_tokens.size() > 5) { - LOG_DBG("draft.size = %zu, n_spec_denials = %d, do_checkpoint = %s, tokens=[..., %d, %d, %d]\n", + if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) { + LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n", + __func__, + cached_text_tokens.size(), draft.size(), spec_ckpt_n_denials, - do_checkpoint ? "yes" : "no", + do_checkpoint ? "yes" : "no", id_last, cached_text_tokens[cached_text_tokens.size() - 3], cached_text_tokens[cached_text_tokens.size() - 2], - cached_text_tokens[cached_text_tokens.size() - 1]); - } - - if (do_checkpoint) { - const size_t n = callback.create_checkpoint(); - if (n == 0) { - LOG_WRN("checkpoint creation failed"); - clear_draft(); - return draft; - } - spec_ckpt_size_part = n; - spec_has_ckpt = true; + cached_text_tokens[cached_text_tokens.size() - 1], + draft[0], draft[1], draft[2]); } if (params_spec.n_min > (int) draft.size()) { @@ -1193,6 +1188,17 @@ struct common_speculative_session::impl { return draft; } + if (do_checkpoint) { + const size_t n = callback.create_checkpoint(); + if (n == 0) { + LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size()); + clear_draft(); + return draft; + } + spec_ckpt_size_part = n; + spec_has_ckpt = true; + } + // add last sampled token to the batch callback.batch_add_token(id_last, true); @@ -1219,27 +1225,31 @@ struct common_speculative_session::impl { if (spec_has_ckpt) { // we need to rollback to the state before sampling the draft tokens const size_t n = callback.restore_checkpoint(spec_ckpt_size_part); - LOG_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n", - ids.size() -1 , n_draft, n); + LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n", + __func__, + ids.size() - 1, n_draft, n); - // rollback to the state before sampling the draft tokens - - // Delete Checkpoint + // delete Checkpoint callback.delete_checkpoint(); spec_has_ckpt = false; - if (n_draft > 0 && spec_ckpt_n_denials == 0) { + spec_ckpt_n_denials++; + if (ids.size() > 1u + static_cast(params_spec.n_min) && spec_ckpt_n_denials == 1) { // we will do the batch again but with the shortened draft - spec_ckpt_n_denials++; - return common_speculative_accept_response(std::move(ids), n_draft, true); } - callback.batch_clear(); + LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size()); + draft.clear(); + + // use the sampled token only + ids.resize(1); + // drafted tokens in prompt have been deleted in restore_checkpoint(...). + return common_speculative_accept_response{std::move(ids), 0, false}; } } const size_t draft_size_accepted = draft.size(); - LOG_DBG("%s: draft.size=%zu\n", __func__, draft_size_accepted); + LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size()); common_speculative_accept(spec, draft_size_accepted); draft.clear(); diff --git a/common/speculative.h b/common/speculative.h index b141938eba..e9595b4bb9 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -61,9 +61,6 @@ struct common_speculative_callback { // Add a token to the draft sequence. virtual void batch_add_token(const llama_token token, bool logits) = 0; - // Clears the batch context. - virtual void batch_clear() = 0; - // Sample and accept tokens from the main model. virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 8343975b19..e30111d4f6 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,3 +1,4 @@ + #include "server-context.h" #include "server-common.h" #include "server-http.h" @@ -57,7 +58,7 @@ struct server_slot { mtmd_context * mctx = nullptr; std::unique_ptr spec_callback; - common_speculative_session * spec_session = nullptr; + std::unique_ptr spec_session = nullptr; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -641,10 +642,6 @@ private: slot.prompt.tokens.push_back(token); } - void batch_clear() override { - common_batch_clear(ctx_impl.batch); - } - std::vector sampler_sample_and_accept_n(const llama_tokens & drafted) override { if (slot.i_batch_dft.size() != 1 + drafted.size()) { GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu", @@ -666,7 +663,7 @@ private: const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max); auto & cur = cur_with_size.checkpoint; - SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); return cur_with_size.size; } @@ -674,7 +671,7 @@ private: size_t restore_checkpoint(size_t ckpt_size_part_expected) override { auto & ckpt = slot.prompt.checkpoints.back(); - SLT_INF(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max); + SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max); const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != ckpt_size_part_expected) { @@ -855,7 +852,7 @@ private: return false; } slot.spec_callback = std::make_unique(slot, *this); - slot.spec_session = new common_speculative_session(*slot.spec_callback, + slot.spec_session = std::make_unique(*slot.spec_callback, params_base.speculative, slot.ctx); SLT_INF(slot, "%s", "speculative decoding context initialized\n"); } @@ -2180,12 +2177,16 @@ private: // generate draft tokens in speculative decoding mode // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] // perform the speculative drafting for all sequences at the same time in a single batch + llama_tokens draft; const int n_draft_max_slot = slot.get_n_draft_max(); - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot); - if (draft.size() > 0) { - SLT_DBG(slot, "compute_draft: #tokens=%d\n", (int) draft.size()); + if (n_draft_max_slot > 0) { + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); + // compute draft and add draft to internal batch + draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot); + if (draft.size() > 0) { + SLT_DBG(slot, "compute_draft: #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n", + cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size()); + } } if (draft.empty()) { @@ -2940,7 +2941,7 @@ private: slot.i_batch_dft.clear(); const size_t n_draft = accept_response.draft_size_initial; if (accept_response.skip_acceptance) { - SLT_INF(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft); + SLT_DBG(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft); continue; } const auto ids = accept_response.tokens;