diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index 2f8a8c7bbe..b41f6bd8ab 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. // update the value statistics - LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", + LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", n_accepted, curr_value.n_accepted); curr_value.n_accepted = n_accepted; } diff --git a/common/speculative.cpp b/common/speculative.cpp index 537aa6cf1f..c145cab398 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -144,10 +144,28 @@ struct common_speculative_state { virtual void accept(uint16_t n_accepted) = 0; }; +struct common_speculative_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + int64_t n_tokens; + + std::vector data; + + size_t size() const { + return data.size(); + } + + size_t ckpt_size; +}; + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; + struct common_speculative_checkpoint ckpt; + bool use_checkpoint; + common_sampler * smpl; llama_batch batch; @@ -160,10 +178,12 @@ struct common_speculative_state_draft : public common_speculative_state { enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, - const std::vector> & replacements) + const std::vector> & replacements, + bool use_checkpoint) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) + , use_checkpoint(use_checkpoint) { batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); smpl = nullptr; @@ -218,7 +238,48 @@ struct common_speculative_state_draft : public common_speculative_state { } void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + if (use_checkpoint && ckpt.size() > 0) { + // delete checkpoint + LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%zu, size=%.3f MiB\n", + __func__, prompt.size(), + ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); + ckpt.pos_min = 0; + ckpt.pos_max = 0; + ckpt.n_tokens = 0; + ckpt.ckpt_size = 0; + ckpt.data.clear(); + } + } + + size_t draft_init_checkpoint(int n_tokens_prompt, int n_tokens_batch) { + int slot_id = 0; + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); + ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); + ckpt.n_tokens = n_tokens_prompt - n_tokens_batch; + ckpt.data.resize(checkpoint_size); + + const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, + ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); + return n; + } + + size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) { + int slot_id = 0; + LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, + ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt_size_part_expected) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); + } + return n; } void draft( @@ -236,8 +297,8 @@ struct common_speculative_state_draft : public common_speculative_state { auto * mem_dft = llama_get_memory(ctx_dft); - int reuse_i = 0; - int reuse_n = 0; + int reuse_i = 0; // index of part to be reused in prompt_dft + int reuse_n = 0; // length of part to be reused in prompt_dft const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; @@ -287,18 +348,26 @@ struct common_speculative_state_draft : public common_speculative_state { } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", + __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); + if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) { + LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", + __func__, reuse_i, reuse_n); + reuse_i = 0; + reuse_n = 0; + } result.clear(); result.reserve(params.n_max); - if (reuse_n == 0) { + bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0; + if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) { llama_memory_clear(mem_dft, false); prompt_dft.clear(); } else { // this happens when a previous draft has been discarded (for example, due to being too small), but the // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { result.push_back(prompt_dft[i]); @@ -310,19 +379,50 @@ struct common_speculative_state_draft : public common_speculative_state { return; } + bool do_restore = false; + if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) { + // This can happen after a partial acceptance (speculative decoding with checkpoints) + LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n", + __func__, prompt_dft.size(), prompt_cur.size()); + prompt_dft.resize(prompt_cur.size()); + do_restore = true; + } + if (reuse_i > 0) { - llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); + } llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); } - if (reuse_n < (int) prompt_dft.size()) { - llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + if (reuse_n < (int) prompt_dft.size() || do_restore) { + if (use_checkpoint) { + if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { + LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%zu, reuse_n=%d, prompt_dft.size=%zu\n", + __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); + } + draft_restore_checkpoint(ckpt.ckpt_size); + reuse_n = ckpt.n_tokens; + prompt_dft.resize(reuse_n); + needs_ckpt = false; + } else { + bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", + __func__, reuse_n, prompt_dft.size()); + } + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } } } + if (needs_ckpt && use_checkpoint) { + ckpt.ckpt_size = draft_init_checkpoint(prompt_dft.size(), batch.n_tokens); + } + // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); @@ -337,7 +437,11 @@ struct common_speculative_state_draft : public common_speculative_state { if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", + __func__, ret, prompt_cur.size()); + } } const llama_pos n_past = prompt_dft.size(); @@ -351,7 +455,11 @@ struct common_speculative_state_draft : public common_speculative_state { LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, ret, prompt_cur.size(), prompt_dft.size()); + } common_sampler_reset(smpl); @@ -387,7 +495,11 @@ struct common_speculative_state_draft : public common_speculative_state { common_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + } prompt_dft.push_back(id); } @@ -909,9 +1021,10 @@ common_speculative * common_speculative_init( break; case COMMON_SPECULATIVE_TYPE_DRAFT: { impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.replacements + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ params.replacements, + /* .use_checkpoint= */ params.use_checkpoints )); break; } @@ -1147,13 +1260,16 @@ struct common_speculative_session::impl { if (spec_ckpt_n_denials == 1) { // there is a previous speculation which wasn't accepted in full length if (draft.empty()) { - LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__); + // switch to non-draft inference + LOG_DBG("%s: draft of length 0 after denied checkpoint\n", __func__); clear_draft(); return draft; } // we use the shortened draft of previous speculation LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__, cached_text_tokens.size(), id_last, draft.size()); + } else if (spec_ckpt_n_denials > 1) { + GGML_ABORT("illegal state: spec_ckpt_n_denials = %d > 1", spec_ckpt_n_denials); } else { // call the speculative implementation to create a draft draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last); @@ -1224,6 +1340,7 @@ struct common_speculative_session::impl { draft.resize(ids.size() - 1); if (spec_has_ckpt) { // we need to rollback to the state before sampling the draft tokens + // (restore_checkpoint shortens context and slot.prompt.tokens) const size_t n = callback.restore_checkpoint(spec_ckpt_size_part); LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n", __func__, @@ -1236,7 +1353,8 @@ struct common_speculative_session::impl { spec_ckpt_n_denials++; if (ids.size() > 1u + static_cast(params_spec.n_min) && spec_ckpt_n_denials == 1) { // we will do the batch again but with the shortened draft - return common_speculative_accept_response(std::move(ids), n_draft, true); + //return common_speculative_accept_response(std::move(ids), n_draft, true); + LOG_DBG("%s: partial draft disabled\n", __func__); } LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size()); @@ -1245,7 +1363,9 @@ struct common_speculative_session::impl { // use the sampled token only ids.resize(1); // drafted tokens in prompt have been deleted in restore_checkpoint(...). - return common_speculative_accept_response{std::move(ids), 0, false}; + + // skip acceptance, don't calculate a new draft + return common_speculative_accept_response{std::move(ids), 0, true}; } } const size_t draft_size_accepted = draft.size(); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index e30111d4f6..bbbd0e2154 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2184,7 +2184,8 @@ private: // compute draft and add draft to internal batch draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot); if (draft.size() > 0) { - SLT_DBG(slot, "compute_draft: #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n", + SLT_DBG(slot, "compute_draft: id=%d, #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n", + slot.sampled, cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size()); } } @@ -2198,7 +2199,8 @@ private: slot.prompt.tokens.push_back(slot.sampled); - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + SLT_DBG(slot, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.sampled, slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); } } @@ -2954,6 +2956,7 @@ private: // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; + slot.n_draft_total += n_draft; // rollback to the state before sampling the draft tokens slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); @@ -2961,6 +2964,7 @@ private: // add accepted tokens to the prompt slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); slot.spec_session->rewind(slot.prompt.n_tokens());