diff --git a/common/speculative.cpp b/common/speculative.cpp
index fff32001ab..edd88050a2 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -10,9 +10,11 @@
 #include "sampling.h"
 
 #include
+#include
 #include
 #include
 #include
+#include
 
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -468,6 +470,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
 struct common_speculative_state_mtp : public common_speculative_state {
     llama_context * ctx_tgt;
     bool cooldown = false; // skip proposal after rejection to get fresh MTP logits
+    std::mt19937 rng{42}; // NOTE(review): unused — the draft selection below is pure argmax and never draws from this RNG; remove it, or wire it up (it also needs <random> in scope)
 
     common_speculative_state_mtp(
         enum common_speculative_type type,
@@ -507,18 +510,28 @@ struct common_speculative_state_mtp : public common_speculative_state {
         }
 
         const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx_tgt)));
+        if (n_vocab <= 0) {
+            return;
+        }
 
-        llama_token best_token = 0;
+        // Argmax of MTP logits — the MTP head is trained to predict the
+        // same token the main model would pick. At temperature=0 (greedy),
+        // this gives ~100% acceptance. At temperature>0, the main model
+        // sometimes samples non-argmax tokens (~5% mismatch at temp=0.6).
+        // This is the expected behavior — temperature sampling on MTP logits
+        // doesn't help because MTP and main model have different distributions.
+        llama_token draft_token = 0;
         float best_logit = mtp_logits[0];
-        for (int i = 1; i < n_vocab; ++i) {
+        for (int i = 1; i < n_vocab; i++) {
             if (mtp_logits[i] > best_logit) {
                 best_logit = mtp_logits[i];
-                best_token = i;
+                draft_token = i;
             }
         }
 
-        if (best_token >= 0 && best_token < n_vocab) {
-            result.push_back(best_token);
+        const auto * vocab = llama_model_get_vocab(llama_get_model(ctx_tgt));
+        if (!llama_vocab_is_eog(vocab, draft_token)) {
+            result.push_back(draft_token);
         }
 
         GGML_UNUSED(id_last);