diff --git a/common/speculative.cpp b/common/speculative.cpp
index 48a121fdef..8bad556d69 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -879,9 +879,19 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {

     // try to remove the last tokens
     if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
-        res = false;
-        goto done;
+        // Check if the model has MTP layers — for MTP-1, we can use
+        // checkpoint/restore instead of seq_rm for the 1-token rollback.
+        // Hybrid SSM models (DeltaNet) support checkpoint/restore via
+        // llama-memory-recurrent.cpp even though they don't support seq_rm.
+        const auto * model = llama_get_model(ctx_tgt);
+        if (model && llama_model_n_mtp_layers(model) > 0) {
+            LOG_INF("%s: seq_rm not supported, but MTP model detected — using checkpoint/restore for rollback\n", __func__);
+            // NOTE(review): nothing to restore here — a failed seq_rm is expected to leave the sequence unmodified; confirm for the recurrent memory backend
+        } else {
+            LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+            res = false;
+            goto done;
+        }
     }

 done:
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 60fae14281..0be0a6d49c 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -148,6 +148,15 @@ struct server_slot {
     llama_token sampled; // in speculative mode, this is the last accepted token
     llama_tokens drafted;

+    // Inline MTP (Multi-Token Prediction) state.
+    // Instead of using the speculative framework (which has M-RoPE and SSM
+    // rollback issues), we propose one draft token from MTP logits and verify
+    // it in the next decode step. No seq_rm or rollback needed.
+ llama_token mtp_draft_token = -1; // proposed draft token (-1 = none) + int mtp_i_batch = -1; // batch index of the draft token + bool mtp_pending = false; // true when draft is in the batch awaiting verification + bool mtp_cooldown = false; // skip MTP proposal for one iteration after draft processing + // stats size_t n_sent_text = 0; // number of sent text character @@ -178,6 +187,10 @@ struct server_slot { drafted.clear(); i_batch_dft.clear(); + mtp_draft_token = -1; + mtp_i_batch = -1; + mtp_pending = false; + mtp_cooldown = false; generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -757,14 +770,16 @@ private: SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - // Auto-detect MTP capability — log presence but don't enable speculative - // decoding framework. The hybrid SSM + M-RoPE architecture is incompatible - // with the speculative verify loop when tool-calling is active. - // MTP tensors are still computed in the forward pass graph (build_mtp_head). + // Auto-detect MTP: if model has MTP layers and no speculative type + // is explicitly set, auto-enable MTP speculative decoding. 
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) { const int32_t n_mtp = llama_model_n_mtp_layers(llama_get_model(ctx)); - if (n_mtp > 0) { - SRV_INF("model has %d MTP layer(s) (graph-only, speculative verify disabled for hybrid models)\n", n_mtp); + if (n_mtp > 0 && can_spec) { + SRV_INF("model has %d MTP layer(s) — auto-enabling MTP speculative decoding\n", n_mtp); + params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + params_base.speculative.n_max = 1; // MTP-1: one draft token per step + } else if (n_mtp > 0) { + SRV_INF("model has %d MTP layer(s) but speculative context not compatible\n", n_mtp); } } @@ -2117,13 +2132,55 @@ private: slot.drafted = std::move(draft); } } else { - // no speculative decoding + // no speculative decoding — but try inline MTP if available slot.i_batch = batch.n_tokens; common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(slot.sampled); + // --- Inline MTP: propose draft token from MTP logits --- + // After adding the sampled token, check MTP logits for a draft. + // Key: do NOT add draft to slot.prompt.tokens yet — only add to + // the batch. If verified next iteration, we add it then. If + // rejected, we decode at the same position again (overwrites + // the draft's KV entry). This avoids llama_memory_seq_rm which + // DeltaNet doesn't support. + // Inline MTP gated by ATLAS_MTP_INLINE env var (default: off until stable) + // Skip proposal during cooldown (after processing a draft) to get + // fresh MTP logits from a clean single-token decode. 
+ if (slot.mtp_cooldown) { + slot.mtp_cooldown = false; + } else if (getenv("ATLAS_MTP_INLINE") && llama_model_n_mtp_layers(llama_get_model(ctx)) > 0 && !slot.mtp_pending) { + float * mtp_logits = llama_get_mtp_logits(ctx); + if (mtp_logits != nullptr) { + const auto * vocab = llama_model_get_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); + if (n_vocab > 0) { + // Find argmax of MTP logits + llama_token draft_id = 0; + float draft_max = mtp_logits[0]; + for (int v = 1; v < n_vocab; v++) { + if (mtp_logits[v] > draft_max) { + draft_max = mtp_logits[v]; + draft_id = v; + } + } + + // Don't draft EOS/special tokens + if (!llama_vocab_is_eog(vocab, draft_id)) { + slot.mtp_draft_token = draft_id; + slot.mtp_i_batch = batch.n_tokens; + slot.mtp_pending = true; + + // Add draft to batch at next position but do NOT + // push to slot.prompt.tokens. If rejected, next + // decode at this position overwrites the KV entry. + common_batch_add(batch, draft_id, slot.prompt.tokens.pos_next(), { slot.id }, true); + } + } + } + } + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); } @@ -2835,13 +2892,77 @@ private: const int tok_idx = slot.i_batch - i; + // --- Inline MTP: verify pending draft --- + // Simplest approach: sample at main position only. + // If it matches draft, the draft was correct — push draft to + // prompt.tokens (it's already in KV cache) and continue normally. + // The "free" token is that we don't need to decode the draft + // since it's already in the KV cache from the batch. + // If rejected, the draft's KV entry gets overwritten next decode. 
+            if (slot.mtp_pending) {
+                // Verify the pending draft: sample at the main token's logits.
+                llama_token id_at_main = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
+                common_sampler_accept(slot.smpl.get(), id_at_main, true);
+
+                bool accepted = false;
+                if (slot.mtp_i_batch >= (int)i && slot.mtp_i_batch < (int)(i + n_tokens)) {
+                    if (id_at_main == slot.mtp_draft_token) {
+                        // Draft correct! It is already in the KV cache from this
+                        // batch, so push it to prompt.tokens to keep position
+                        // tracking in sync with the KV cache. The "free" token
+                        // comes from sampling the follow-up token off the
+                        // draft's own logits below, instead of re-decoding the
+                        // draft (which would duplicate it in the context).
+                        slot.prompt.tokens.push_back(slot.mtp_draft_token);
+                        slot.n_draft_accepted += 1;
+                        accepted = true;
+                    }
+                    // If rejected: KV entry at draft pos gets overwritten.
+                    slot.n_draft_total += 1;
+                }
+
+                // On acceptance, slot.sampled must be the token AFTER the draft;
+                // sampling it from the draft's batch index (its logits were
+                // requested when it was added) keeps the next iteration's
+                // push_back(slot.sampled) from duplicating the draft token.
+                llama_token id_next = id_at_main;
+                if (accepted) {
+                    id_next = common_sampler_sample(slot.smpl.get(), ctx, slot.mtp_i_batch - i);
+                    common_sampler_accept(slot.smpl.get(), id_next, true);
+                }
+
+                slot.sampled = id_next;
+                slot.mtp_pending = false;
+                slot.mtp_i_batch = -1;
+                slot.mtp_draft_token = -1;
+                slot.mtp_cooldown = true;
+                slot.i_batch = -1;
+
+                const int64_t t_current = ggml_time_us();
+                slot.n_decoded += 1;
+                if (slot.n_decoded == 1) {
+                    slot.t_start_generation = t_current;
+                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                    metrics.on_prompt_eval(slot);
+                }
+                if (accepted) {
+                    slot.n_decoded += 1;
+                }
+                // std::max(1, ...) is ill-formed here: t_current is int64_t, so
+                // both arguments must deduce the same type — pin it explicitly.
+                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
+
+                completion_token_output result_main;
+                result_main.tok = id_at_main;
+                result_main.text_to_send = common_token_to_piece(ctx, result_main.tok, accept_special_token(slot, result_main.tok));
+                result_main.prob = 1.0f;
+
+                if (!process_token(result_main, slot)) {
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    slot.release();
+                    continue;
+                }
+
+                if (accepted) {
+                    // Also emit the follow-up token sampled from the draft's logits.
+                    completion_token_output result_next;
+                    result_next.tok = id_next;
+                    result_next.text_to_send = common_token_to_piece(ctx, result_next.tok, accept_special_token(slot, result_next.tok));
+                    result_next.prob = 1.0f;
+
+                    if (!process_token(result_next, slot)) {
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        slot.release();
+                        continue;
+                    }
+                }
+
+                continue; // done with this slot for this decode step
+            }
+
+            // --- Normal sampling (no pending MTP draft) ---
             llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);

             slot.i_batch = -1;

             common_sampler_accept(slot.smpl.get(), id, true);

-            // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
             const int64_t t_current = ggml_time_us();

             slot.n_decoded += 1;

@@ -2857,14 +2978,13 @@
             completion_token_output result;
             result.tok = id;
             result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
+            result.prob = 1.0f;

             if (slot.task->params.sampling.n_probs > 0) {
                 populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
             }

             if (!process_token(result, slot)) {
-                // release slot because of stop condition
                 slot.print_timings();
                 send_final_response(slot);
                 metrics.on_prediction(slot);
@@ -2906,7 +3026,37 @@
             slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
             slot.sampled = ids.back(); // last accepted token

-            llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
+            // Remove rejected draft tokens from KV cache.
+            // For standard transformers, seq_rm cleanly removes entries.
+            // For hybrid SSM/DeltaNet models, seq_rm may fail because
+            // recurrent state can't be partially rewound. In that case, we
+            // need to re-decode the accepted tokens to rebuild the correct
+            // recurrent state from scratch.
+            if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1)) {
+                // seq_rm failed (hybrid model). Clear the sequence and
+                // re-evaluate accepted tokens to rebuild recurrent state.
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+
+                // Re-evaluate accepted prompt tokens from scratch.
+                // This is expensive but correct — only happens on draft
+                // rejection for hybrid models. MTP-1 has ~50% acceptance
+                // so this runs roughly every other step.
+                const auto & prompt_tokens = slot.prompt.tokens.get_text_tokens();
+                if (!prompt_tokens.empty()) {
+                    const int n_prompt = (int)prompt_tokens.size();
+                    const int n_batch_re = llama_n_batch(ctx);
+                    for (int j = 0; j < n_prompt; j += n_batch_re) {
+                        const int n_eval = std::min(n_batch_re, n_prompt - j);
+                        llama_batch batch_re = llama_batch_get_one(
+                            const_cast<llama_token *>(prompt_tokens.data()) + j, n_eval);
+                        // batch_re.logits is nullptr here: llama_decode defaults
+                        // to outputting logits for the last token only, which is
+                        // all we need — do not dereference the null pointer.
+                        if (llama_decode(ctx, batch_re) != 0) {
+                            SRV_ERR("%s", "mtp: re-evaluation decode failed in seq_rm fallback\n");
+                        }
+                    }
+                }
+            }

             for (size_t i = 0; i < ids.size(); ++i) {
                 completion_token_output result;