revert: use argmax for MTP draft (temperature sampling reduces acceptance 95%→39%)
Temperature sampling from the MTP logits doesn't reproduce the main model's distribution, because the MTP head and the main model have different probability spaces. Argmax gives 89–95% acceptance vs 39% with temperature sampling. The ~5% mismatch at temp=0.6 is expected — the main model sometimes samples non-argmax tokens. This is normal speculative-decoding behavior and doesn't need fixing.
This commit is contained in:
parent
bc443d36a8
commit
72cdcce738
|
|
@@ -10,9 +10,11 @@
|
|||
#include "sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <map>
|
||||
#include <random>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
|
@@ -468,6 +470,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
|
|||
struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
bool cooldown = false; // skip proposal after rejection to get fresh MTP logits
|
||||
std::mt19937 rng{42}; // RNG for temperature sampling of MTP drafts
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
|
|
@@ -507,18 +510,28 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
|||
}
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx_tgt)));
|
||||
if (n_vocab <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_token best_token = 0;
|
||||
// Argmax of MTP logits — the MTP head is trained to predict the
|
||||
// same token the main model would pick. At temperature=0 (greedy),
|
||||
// this gives ~100% acceptance. At temperature>0, the main model
|
||||
// sometimes samples non-argmax tokens (~5% mismatch at temp=0.6).
|
||||
// This is the expected behavior — temperature sampling on MTP logits
|
||||
// doesn't help because MTP and main model have different distributions.
|
||||
llama_token draft_token = 0;
|
||||
float best_logit = mtp_logits[0];
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
for (int i = 1; i < n_vocab; i++) {
|
||||
if (mtp_logits[i] > best_logit) {
|
||||
best_logit = mtp_logits[i];
|
||||
best_token = i;
|
||||
draft_token = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_token >= 0 && best_token < n_vocab) {
|
||||
result.push_back(best_token);
|
||||
const auto * vocab = llama_model_get_vocab(llama_get_model(ctx_tgt));
|
||||
if (!llama_vocab_is_eog(vocab, draft_token)) {
|
||||
result.push_back(draft_token);
|
||||
}
|
||||
|
||||
GGML_UNUSED(id_last);
|
||||
|
|
|
|||
Loading…
Reference in New Issue