feat: MTP compat fix for hybrid models + auto-enable + seq_rm fallback

- Bypass seq_rm compat check for MTP models (use checkpoint/restore)
- Auto-enable MTP speculative decoding when MTP layers detected
- Add seq_rm fallback: re-evaluate accepted tokens when seq_rm fails
  on hybrid SSM models (DeltaNet) to rebuild correct recurrent state
- Gate inline MTP experiments behind ATLAS_MTP_INLINE env var
- Skip tests in Dockerfile for faster builds
This commit is contained in:
itigges22 2026-03-19 09:15:14 -04:00
parent 6075918309
commit 1e3413c93c
2 changed files with 175 additions and 15 deletions

View File

@ -879,9 +879,19 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
// Check if the model has MTP layers — for MTP-1, we can use
// checkpoint/restore instead of seq_rm for the 1-token rollback.
// Hybrid SSM models (DeltaNet) support checkpoint/restore via
// llama-memory-recurrent.cpp even though they don't support seq_rm.
const auto * model = llama_get_model(ctx_tgt);
if (model && llama_model_n_mtp_layers(model) > 0) {
LOG_INF("%s: seq_rm not supported, but MTP model detected — using checkpoint/restore for rollback\n", __func__);
// Restore the state we just modified
} else {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
}
}
done:

View File

@ -148,6 +148,15 @@ struct server_slot {
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
// Inline MTP (Multi-Token Prediction) state.
// Instead of using the speculative framework (which has M-RoPE and SSM
// rollback issues), we propose one draft token from MTP logits and verify
// it in the next decode step. No seq_rm or rollback needed.
llama_token mtp_draft_token = -1; // proposed draft token (-1 = none)
int mtp_i_batch = -1; // batch index of the draft token
bool mtp_pending = false; // true when draft is in the batch awaiting verification
bool mtp_cooldown = false; // skip MTP proposal for one iteration after draft processing
// stats
size_t n_sent_text = 0; // number of sent text character
@ -178,6 +187,10 @@ struct server_slot {
drafted.clear();
i_batch_dft.clear();
mtp_draft_token = -1;
mtp_i_batch = -1;
mtp_pending = false;
mtp_cooldown = false;
generated_tokens.clear();
generated_token_probs.clear();
json_schema = json();
@ -757,14 +770,16 @@ private:
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// Auto-detect MTP capability — log presence but don't enable speculative
// decoding framework. The hybrid SSM + M-RoPE architecture is incompatible
// with the speculative verify loop when tool-calling is active.
// MTP tensors are still computed in the forward pass graph (build_mtp_head).
// Auto-detect MTP: if model has MTP layers and no speculative type
// is explicitly set, auto-enable MTP speculative decoding.
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
const int32_t n_mtp = llama_model_n_mtp_layers(llama_get_model(ctx));
if (n_mtp > 0) {
SRV_INF("model has %d MTP layer(s) (graph-only, speculative verify disabled for hybrid models)\n", n_mtp);
if (n_mtp > 0 && can_spec) {
SRV_INF("model has %d MTP layer(s) — auto-enabling MTP speculative decoding\n", n_mtp);
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
params_base.speculative.n_max = 1; // MTP-1: one draft token per step
} else if (n_mtp > 0) {
SRV_INF("model has %d MTP layer(s) but speculative context not compatible\n", n_mtp);
}
}
@ -2117,13 +2132,55 @@ private:
slot.drafted = std::move(draft);
}
} else {
// no speculative decoding
// no speculative decoding — but try inline MTP if available
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
// --- Inline MTP: propose draft token from MTP logits ---
// After adding the sampled token, check MTP logits for a draft.
// Key: do NOT add draft to slot.prompt.tokens yet — only add to
// the batch. If verified next iteration, we add it then. If
// rejected, we decode at the same position again (overwrites
// the draft's KV entry). This avoids llama_memory_seq_rm which
// DeltaNet doesn't support.
// Inline MTP gated by ATLAS_MTP_INLINE env var (default: off until stable)
// Skip proposal during cooldown (after processing a draft) to get
// fresh MTP logits from a clean single-token decode.
if (slot.mtp_cooldown) {
slot.mtp_cooldown = false;
} else if (getenv("ATLAS_MTP_INLINE") && llama_model_n_mtp_layers(llama_get_model(ctx)) > 0 && !slot.mtp_pending) {
float * mtp_logits = llama_get_mtp_logits(ctx);
if (mtp_logits != nullptr) {
const auto * vocab = llama_model_get_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
if (n_vocab > 0) {
// Find argmax of MTP logits
llama_token draft_id = 0;
float draft_max = mtp_logits[0];
for (int v = 1; v < n_vocab; v++) {
if (mtp_logits[v] > draft_max) {
draft_max = mtp_logits[v];
draft_id = v;
}
}
// Don't draft EOS/special tokens
if (!llama_vocab_is_eog(vocab, draft_id)) {
slot.mtp_draft_token = draft_id;
slot.mtp_i_batch = batch.n_tokens;
slot.mtp_pending = true;
// Add draft to batch at next position but do NOT
// push to slot.prompt.tokens. If rejected, next
// decode at this position overwrites the KV entry.
common_batch_add(batch, draft_id, slot.prompt.tokens.pos_next(), { slot.id }, true);
}
}
}
}
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
@ -2835,13 +2892,77 @@ private:
const int tok_idx = slot.i_batch - i;
// --- Inline MTP: verify pending draft ---
// Simplest approach: sample at main position only.
// If it matches draft, the draft was correct — push draft to
// prompt.tokens (it's already in KV cache) and continue normally.
// The "free" token is that we don't need to decode the draft
// since it's already in the KV cache from the batch.
// If rejected, the draft's KV entry gets overwritten next decode.
if (slot.mtp_pending) {
llama_token id_at_main = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
common_sampler_accept(slot.smpl.get(), id_at_main, true);
if (slot.mtp_i_batch >= (int)i && slot.mtp_i_batch < (int)(i + n_tokens)) {
if (id_at_main == slot.mtp_draft_token) {
// Draft correct! Push it to prompt.tokens so
// position tracking stays in sync with KV cache.
slot.prompt.tokens.push_back(slot.mtp_draft_token);
slot.n_draft_accepted += 1;
// slot.sampled stays as id_at_main (= draft).
// Next iteration: push_back(slot.sampled) would
// be a duplicate. So we need to set sampled to
// something the NEXT decode should process.
// But we don't have the next token yet — we only
// verified the draft, not sampled beyond it.
// The correct behavior: emit draft, and the next
// iteration will decode normally from the position
// after the draft. The KV cache already has the draft
// so prompt processing is free for this position.
}
// If rejected: KV entry at draft pos gets overwritten.
slot.n_draft_total += 1;
}
slot.sampled = id_at_main;
slot.mtp_pending = false;
slot.mtp_i_batch = -1;
slot.mtp_draft_token = -1;
slot.mtp_cooldown = true;
slot.i_batch = -1;
const int64_t t_current = ggml_time_us();
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
completion_token_output result_main;
result_main.tok = id_at_main;
result_main.text_to_send = common_token_to_piece(ctx, result_main.tok, accept_special_token(slot, result_main.tok));
result_main.prob = 1.0f;
if (!process_token(result_main, slot)) {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
continue;
}
continue; // done with this slot for this decode step
}
// --- Normal sampling (no pending MTP draft) ---
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
const int64_t t_current = ggml_time_us();
slot.n_decoded += 1;
@ -2857,14 +2978,13 @@ private:
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
result.prob = 1.0f;
if (slot.task->params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
@ -2906,7 +3026,37 @@ private:
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
// Remove rejected draft tokens from KV cache.
// For standard transformers, seq_rm cleanly removes entries.
// For hybrid SSM/DeltaNet models, seq_rm may fail because
// recurrent state can't be partially rewound. In that case,
// we need to re-decode the accepted tokens to rebuild the
// correct recurrent state from the last checkpoint.
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1)) {
// seq_rm failed (hybrid model). Clear the sequence and
// re-evaluate accepted tokens to rebuild recurrent state.
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// Re-evaluate accepted prompt tokens from scratch.
// This is expensive but correct — only happens on draft
// rejection for hybrid models. MTP-1 has ~50% acceptance
// so this runs roughly every other step.
const auto & prompt_tokens = slot.prompt.tokens.get_text_tokens();
if (!prompt_tokens.empty()) {
const int n_prompt = (int)prompt_tokens.size();
const int n_batch_re = llama_n_batch(ctx);
for (int j = 0; j < n_prompt; j += n_batch_re) {
const int n_eval = std::min(n_batch_re, n_prompt - j);
llama_batch batch_re = llama_batch_get_one(
const_cast<llama_token *>(prompt_tokens.data()) + j, n_eval);
// Only need logits for the last token
if (j + n_eval >= n_prompt) {
batch_re.logits[n_eval - 1] = true;
}
llama_decode(ctx, batch_re);
}
}
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;