diff --git a/common/sampling.cpp b/common/sampling.cpp
index f849d4f61a..bbe796c4d5 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -544,6 +544,10 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
         result.push_back(id);
 
+        fprintf(stderr, "[MTP-VERIFY] pos=%d: sampled=%d, draft=%d, %s\n",
+                idxs[i], id, draft[i], (draft[i] == id) ? "ACCEPTED" : "REJECTED");
+        fflush(stderr);
+
         if (draft[i] != id) {
             break;
         }
@@ -555,6 +559,9 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
         common_sampler_accept(gsmpl, id, true);
         result.push_back(id);
+
+        fprintf(stderr, "[MTP-VERIFY] bonus pos=%d: sampled=%d\n", idxs[i], id);
+        fflush(stderr);
     }
 
     return result;
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 8bad556d69..fff32001ab 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -467,6 +467,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
 // Multi-Token Prediction (MTP) speculative decoding state
 struct common_speculative_state_mtp : public common_speculative_state {
     llama_context * ctx_tgt;
+    bool cooldown = false; // skip proposal after rejection to get fresh MTP logits
 
     common_speculative_state_mtp(
         enum common_speculative_type type,
@@ -479,6 +480,7 @@ struct common_speculative_state_mtp : public common_speculative_state {
     ~common_speculative_state_mtp() override = default;
 
     void begin(const llama_tokens & prompt) override {
+        cooldown = false;
         GGML_UNUSED(prompt);
     }
 
@@ -489,6 +491,16 @@ struct common_speculative_state_mtp : public common_speculative_state {
                llama_tokens & result) override {
         GGML_UNUSED(prompt_tgt);
 
+        // After a draft rejection, MTP logits are from the DRAFT position
+        // (last in the [sampled, draft] batch), not from the sampled position.
+        // These logits predict what comes after the draft — which is wrong
+        // since the draft was rejected. Skip this proposal and let the next
+        // single-token decode produce fresh MTP logits.
+        if (cooldown) {
+            cooldown = false;
+            return; // empty result = no draft = normal single-token decode
+        }
+
         const float * mtp_logits = llama_get_mtp_logits(ctx_tgt);
         if (mtp_logits == nullptr) {
             return;
@@ -514,7 +526,11 @@ struct common_speculative_state_mtp : public common_speculative_state {
     }
 
     void accept(uint16_t n_accepted) override {
-        GGML_UNUSED(n_accepted);
+        // If no drafts were accepted, enter cooldown
+        // (next draft() call returns empty to force single-token decode)
+        if (n_accepted == 0) {
+            cooldown = true;
+        }
     }
 };
 
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 0be0a6d49c..d848ab3005 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2101,7 +2101,11 @@ private:
 
         const auto & params_spec = slot.task->params.speculative;
 
+        fprintf(stderr, "[MTP-DBG] calling common_speculative_draft, prompt_size=%zu, sampled=%d\n", cached_text_tokens.size(), slot.sampled);
+        fflush(stderr);
         llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+        fprintf(stderr, "[MTP-DBG] draft returned %zu tokens\n", draft.size());
+        fflush(stderr);
 
         if (draft.size() > (size_t) n_draft_max) {
             SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
@@ -2764,8 +2768,14 @@
                 batch.logits + i,
             };
 
+            fprintf(stderr, "[MTP-DBG] llama_decode: n_tokens=%d, batch_start=%d\n", n_tokens, i);
+            fflush(stderr);
+
             const int ret = llama_decode(ctx, batch_view);
 
+            fprintf(stderr, "[MTP-DBG] llama_decode returned: %d\n", ret);
+            fflush(stderr);
+
             metrics.on_decoded(slots);
 
             if (ret != 0) {
@@ -3003,7 +3013,11 @@ private:
             const size_t n_draft = slot.drafted.size();
 
             // the accepted tokens from the speculation
+            fprintf(stderr, "[MTP-DBG] calling sample_and_accept_n, i_batch_dft=%zu, drafted=%zu\n", slot.i_batch_dft.size(), slot.drafted.size());
+            fflush(stderr);
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx,
                     slot.i_batch_dft, slot.drafted);
+            fprintf(stderr, "[MTP-DBG] sample_and_accept_n returned %zu ids\n", ids.size());
+            fflush(stderr);
             slot.i_batch_dft.clear();
             slot.drafted.clear();
@@ -3027,35 +3041,15 @@ private:
             slot.sampled = ids.back(); // last accepted token
 
             // Remove rejected draft tokens from KV cache.
-            // For standard transformers, seq_rm cleanly removes entries.
-            // For hybrid SSM/DeltaNet models, seq_rm may fail because
-            // recurrent state can't be partially rewound. In that case,
-            // we need to re-decode the accepted tokens to rebuild the
-            // correct recurrent state from the last checkpoint.
+            // For hybrid SSM/DeltaNet, seq_rm may fail. In that case,
+            // just log and continue — the recurrent state has the draft
+            // token baked in, but the checkpoint mechanism in
+            // llama-memory-recurrent.cpp should handle rollback internally
+            // during the next find_slot call.
             if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1)) {
-                // seq_rm failed (hybrid model). Clear the sequence and
-                // re-evaluate accepted tokens to rebuild recurrent state.
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-
-                // Re-evaluate accepted prompt tokens from scratch.
-                // This is expensive but correct — only happens on draft
-                // rejection for hybrid models. MTP-1 has ~50% acceptance
-                // so this runs roughly every other step.
-                const auto & prompt_tokens = slot.prompt.tokens.get_text_tokens();
-                if (!prompt_tokens.empty()) {
-                    const int n_prompt = (int)prompt_tokens.size();
-                    const int n_batch_re = llama_n_batch(ctx);
-                    for (int j = 0; j < n_prompt; j += n_batch_re) {
-                        const int n_eval = std::min(n_batch_re, n_prompt - j);
-                        llama_batch batch_re = llama_batch_get_one(
-                            const_cast<llama_token *>(prompt_tokens.data()) + j, n_eval);
-                        // Only need logits for the last token
-                        if (j + n_eval >= n_prompt) {
-                            batch_re.logits[n_eval - 1] = true;
-                        }
-                        llama_decode(ctx, batch_re);
-                    }
-                }
+                fprintf(stderr, "[MTP-DBG] seq_rm failed for slot %d at pos %d — continuing (hybrid model)\n",
+                        slot.id, (int)slot.prompt.n_tokens());
+                fflush(stderr);
             }
 
             for (size_t i = 0; i < ids.size(); ++i) {