revert: use argmax for MTP draft (temperature sampling reduces acceptance 95%→39%)

Temperature sampling from MTP logits doesn't match the main model's
distribution because they have different probability spaces. Argmax
gives 89-95% acceptance vs 39% with temperature sampling.

The 5% mismatch at temp=0.6 is expected — the main model sometimes
samples non-argmax tokens. This is the natural speculative decoding
behavior and doesn't need fixing.
This commit is contained in:
itigges22 2026-03-19 19:41:38 -04:00
parent bc443d36a8
commit 72cdcce738
1 changed file with 18 additions and 5 deletions

View File

@@ -10,9 +10,11 @@
#include "sampling.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <map>
#include <random>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -468,6 +470,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
bool cooldown = false; // skip proposal after rejection to get fresh MTP logits
std::mt19937 rng{42}; // RNG for temperature sampling of MTP drafts
common_speculative_state_mtp(
enum common_speculative_type type,
@@ -507,18 +510,28 @@ struct common_speculative_state_mtp : public common_speculative_state {
}
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx_tgt)));
if (n_vocab <= 0) {
return;
}
llama_token best_token = 0;
// Argmax of MTP logits — the MTP head is trained to predict the
// same token the main model would pick. At temperature=0 (greedy),
// this gives ~100% acceptance. At temperature>0, the main model
// sometimes samples non-argmax tokens (~5% mismatch at temp=0.6).
// This is the expected behavior — temperature sampling on MTP logits
// doesn't help because MTP and main model have different distributions.
llama_token draft_token = 0;
float best_logit = mtp_logits[0];
for (int i = 1; i < n_vocab; ++i) {
for (int i = 1; i < n_vocab; i++) {
if (mtp_logits[i] > best_logit) {
best_logit = mtp_logits[i];
best_token = i;
draft_token = i;
}
}
if (best_token >= 0 && best_token < n_vocab) {
result.push_back(best_token);
const auto * vocab = llama_model_get_vocab(llama_get_model(ctx_tgt));
if (!llama_vocab_is_eog(vocab, draft_token)) {
result.push_back(draft_token);
}
GGML_UNUSED(id_last);