From 19fdba56b5fc86f9a23bb8d1ab0e0c899c6b6a26 Mon Sep 17 00:00:00 2001 From: itigges22 Date: Sat, 21 Mar 2026 13:51:30 -0400 Subject: [PATCH] feat: MTP support for dense Qwen3.5 with FastMTP vocabulary trimming Add Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense models (0.8B-27B). The MTP head uses a full transformer block (attention + FFN) to predict the next-next token, enabling ~28 tok/s on RTX 5060 Ti. Key changes: - Model loading: Qwen3.5 MTP layer tensors (nextn.eh_proj, attention weights, FFN) loaded into the layer slots appended after the main layers (layers[n_layer-1] for a single MTP layer) - Graph builder: Full MTP head with self-attention, gated RoPE, FFN, and vocabulary projection. Unfiltered hidden state passed for proper KV cache population during prompt processing. - FastMTP: Vocabulary trimming from 248K to 32K tokens via ggml_view_2d on the lm_head. Reduces draft generation time from 22ms to 6ms (3.7x). - Speculative framework: MTP auto-detection for hybrid models, fuzzy seq_rm checkpoint matching for DeltaNet rollback. - Server: Two-phase decode option for hybrid/recurrent models to avoid DeltaNet state corruption from rejected drafts. - Recurrent state: Fixed copy_cell (ggml_view_1d takes an element count, not a byte count) and the buffer assignment for no_alloc views. 
Results on Qwen3.5-9B Q4_K_M (RTX 5060 Ti 16GB): - 28.1 tok/s with 82% acceptance rate (temp=0) - 92% acceptance with two-phase decode (correct output, 15 tok/s) - Draft generation: 6.1ms with FastMTP (vs 22.4ms full vocab) --- Dockerfile.atlas | 17 +++ common/arg.cpp | 7 +- common/common.h | 1 + common/sampling.cpp | 7 ++ common/speculative.cpp | 107 +++++++++++++++- convert_hf_to_gguf.py | 49 ++++++++ gguf-py/gguf/constants.py | 9 +- include/llama.h | 8 ++ src/llama-arch.cpp | 22 ++-- src/llama-batch.cpp | 5 +- src/llama-context.cpp | 35 ++++++ src/llama-context.h | 7 ++ src/llama-graph.h | 4 + src/llama-memory-recurrent.cpp | 200 +++++++++++++++++++++++++---- src/llama-memory-recurrent.h | 4 + src/llama-model.cpp | 122 +++++++++++++----- src/models/models.h | 7 ++ src/models/qwen35.cpp | 216 +++++++++++++++++++++++++++----- tools/server/server-context.cpp | 182 ++++++++++++++++++++++++--- 19 files changed, 893 insertions(+), 116 deletions(-) create mode 100644 Dockerfile.atlas diff --git a/Dockerfile.atlas b/Dockerfile.atlas new file mode 100644 index 0000000000..e0604f2d05 --- /dev/null +++ b/Dockerfile.atlas @@ -0,0 +1,17 @@ +FROM docker.io/nvidia/cuda:12.8.0-devel-rockylinux9 AS builder +RUN dnf install -y cmake gcc-c++ && dnf clean all +ENV TMPDIR=/llama.cpp/tmp + +# Copy local source with inline MTP changes +COPY . 
/llama.cpp +RUN cd /llama.cpp && \ + mkdir -p /llama.cpp/tmp && \ + cmake -B build -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=OFF -DCMAKE_CUDA_ARCHITECTURES=120 -DLLAMA_BUILD_TESTS=OFF && \ + cmake --build build --target llama-server llama-cli --config Release -j5 + +FROM docker.io/nvidia/cuda:12.8.0-runtime-rockylinux9 +COPY --from=builder /llama.cpp/build/bin/llama-server /usr/local/bin/ +COPY --from=builder /llama.cpp/build/bin/llama-cli /usr/local/bin/ +RUN mkdir -p /models /templates +EXPOSE 8000 +ENTRYPOINT ["/entrypoint.sh"] diff --git a/common/arg.cpp b/common/arg.cpp index aad70ec546..8bd1018d3f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3474,8 +3474,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( - {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", - string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", + {"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n" + " mtp: use model's built-in Multi-Token Prediction head (requires MTP-capable model)\n", common_speculative_type_to_str(params.speculative.type).c_str()), [](common_params & params, const std::string & value) { if (value == "none") { @@ -3490,6 +3491,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; } else if (value == "ngram-mod") { params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + } else if (value == "mtp") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; } else { throw std::invalid_argument("unknown speculative decoding type without draft model"); } diff --git a/common/common.h b/common/common.h index 
62201ea1ad..a01553db74 100644 --- a/common/common.h +++ b/common/common.h @@ -172,6 +172,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT, // draft model COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model + COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction (uses model's built-in MTP head) COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values diff --git a/common/sampling.cpp b/common/sampling.cpp index 012e212660..355ba9317e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -577,6 +577,10 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample result.push_back(id); + fprintf(stderr, "[MTP-VERIFY] pos=%d: sampled=%d, draft=%d, %s\n", + idxs[i], id, draft[i], (draft[i] == id) ? 
"ACCEPTED" : "REJECTED"); + fflush(stderr); + if (draft[i] != id) { break; } @@ -588,6 +592,9 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample common_sampler_accept(gsmpl, id, true); result.push_back(id); + + fprintf(stderr, "[MTP-VERIFY] bonus pos=%d: sampled=%d\n", idxs[i], id); + fflush(stderr); } return result; diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e68c38e49..ae66271cee 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -10,9 +10,11 @@ #include "sampling.h" #include +#include #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -21,6 +23,7 @@ const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NONE, COMMON_SPECULATIVE_TYPE_DRAFT, COMMON_SPECULATIVE_TYPE_EAGLE3, + COMMON_SPECULATIVE_TYPE_MTP, COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, @@ -32,6 +35,7 @@ const std::map common_speculative_typ {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, + {"mtp", COMMON_SPECULATIVE_TYPE_MTP}, {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -462,6 +466,84 @@ struct common_speculative_state_eagle3 : public common_speculative_state { } }; +// Multi-Token Prediction (MTP) speculative decoding state +struct common_speculative_state_mtp : public common_speculative_state { + llama_context * ctx_tgt; + bool cooldown = false; // skip proposal after rejection to get fresh MTP logits + std::mt19937 rng{42}; // RNG for temperature sampling of MTP drafts + + common_speculative_state_mtp( + enum common_speculative_type type, + llama_context * ctx_tgt) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + { + } + + ~common_speculative_state_mtp() override = default; 
+ + void begin(const llama_tokens & prompt) override { + cooldown = false; + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(prompt_tgt); + + // After a draft rejection, MTP logits are from the DRAFT position + // (last in the [sampled, draft] batch), not from the sampled position. + // These logits predict what comes after the draft — which is wrong + // since the draft was rejected. Skip this proposal and let the next + // single-token decode produce fresh MTP logits. + if (cooldown) { + cooldown = false; + return; // empty result = no draft = normal single-token decode + } + + const float * mtp_logits = llama_get_mtp_logits(ctx_tgt); + if (mtp_logits == nullptr) { + return; + } + + // FastMTP: use reduced vocab size (e.g., 32K instead of 248K) + // Token IDs 0..mtp_n_vocab-1 map directly to full vocab IDs + const int64_t mtp_n_vocab = llama_get_mtp_n_vocab(ctx_tgt); + if (mtp_n_vocab <= 0) { + return; + } + + // Argmax of MTP logits over reduced vocabulary + llama_token draft_token = 0; + float best_logit = mtp_logits[0]; + for (int64_t i = 1; i < mtp_n_vocab; i++) { + if (mtp_logits[i] > best_logit) { + best_logit = mtp_logits[i]; + draft_token = (llama_token)i; + } + } + + const auto * vocab = llama_model_get_vocab(llama_get_model(ctx_tgt)); + if (!llama_vocab_is_eog(vocab, draft_token)) { + result.push_back(draft_token); + } + + GGML_UNUSED(id_last); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + // If no drafts were accepted, enter cooldown + // (next draft() call returns empty to force single-token decode) + if (n_accepted == 0) { + cooldown = true; + } + } +}; + // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config; @@ -781,6 +863,7 @@ std::string 
common_speculative_type_to_str(enum common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; @@ -822,9 +905,19 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) { // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; - goto done; + // Check if the model has MTP layers — for MTP-1, we can use + // checkpoint/restore instead of seq_rm for the 1-token rollback. + // Hybrid SSM models (DeltaNet) support checkpoint/restore via + // llama-memory-recurrent.cpp even though they don't support seq_rm. 
+ const auto * model = llama_get_model(ctx_tgt); + if (model && llama_model_n_mtp_layers(model) > 0) { + LOG_INF("%s: seq_rm not supported, but MTP model detected — using checkpoint/restore for rollback\n", __func__); + // Restore the state we just modified + } else { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } } done: @@ -853,6 +946,7 @@ common_speculative * common_speculative_init( { bool has_draft = !params.mparams_dft.path.empty(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP); bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); @@ -892,6 +986,9 @@ common_speculative * common_speculative_init( if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_mtp) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); + } if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } @@ -919,6 +1016,10 @@ common_speculative * common_speculative_init( impls.push_back(std::make_unique(config.type)); break; } + case COMMON_SPECULATIVE_TYPE_MTP: { + impls.push_back(std::make_unique(config.type, ctx_tgt)); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dba190b480..c3d4438194 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5046,6 +5046,55 @@ class _LinearAttentionVReorderBase(Qwen3NextModel): class Qwen3_5TextModel(_LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # If model has MTP layers, include them in block_count + 
mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0) + if mtp_layers > 0: + self.block_count = self.hparams["num_hidden_layers"] + mtp_layers + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0) + if mtp_layers > 0: + self.gguf_writer.add_nextn_predict_layers(mtp_layers) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("mtp."): + num_hidden = self.hparams["num_hidden_layers"] + + if "layers." in name: + # Remap MTP transformer block tensors to append after main layers + # mtp.layers.{k}.* -> model.layers.{k + num_hidden_layers}.* + new_bid = (bid or 0) + num_hidden + name = name.replace(f"mtp.layers.{bid}", f"model.layers.{new_bid}") + yield from super().modify_tensors(data_torch, name, new_bid) + else: + # Shared MTP weights -> nextn tensor slots + from pathlib import Path + remapper = { + "mtp.fc": "model.layers.{bid}.eh_proj", + "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm", + "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm", + "mtp.norm": "model.layers.{bid}.shared_head.norm", + } + _n = Path(name) + matched = False + for prefix, template in remapper.items(): + if name.startswith(prefix): + suffix = name[len(prefix):] # e.g. ".weight" + for b in range(num_hidden, self.block_count): + new_name = template.format(bid=b) + suffix + yield from super().modify_tensors(data_torch, new_name, b) + matched = True + break + if not matched: + # Skip unknown MTP tensors (e.g. 
embed_tokens/lm_head if shared) + pass + return + yield from super().modify_tensors(data_torch, name, bid) + @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM") class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5f92c7700..68c721e9e0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1898,7 +1898,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.QWEN35MOE: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/include/llama.h b/include/llama.h index 6e72db7e3c..dde2bc7213 100644 --- a/include/llama.h +++ b/include/llama.h @@ -557,6 +557,9 @@ extern "C" { LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + // Returns the number of Multi-Token Prediction layers (0 if MTP is not available) + LLAMA_API int32_t llama_model_n_mtp_layers(const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); @@ -988,6 +991,11 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + // Get MTP (Multi-Token Prediction) draft logits for the last output position. + // With FastMTP, returns mtp_n_vocab floats (reduced vocabulary). Use llama_get_mtp_n_vocab(). + LLAMA_API float * llama_get_mtp_logits(struct llama_context * ctx); + LLAMA_API int64_t llama_get_mtp_n_vocab(struct llama_context * ctx); + // Get all output token embeddings. 
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 84dc6d8f1b..c2b66d408a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1051,6 +1051,13 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_ALPHA, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + // NextN/MTP tensors + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; case LLM_ARCH_QWEN35MOE: return { @@ -2753,14 +2760,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors — per-layer (appended after main layers) + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + 
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cd..c9fe274f0c 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -262,12 +262,13 @@ bool llama_batch_allocr::init( const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1; if (batch.token) { - if (p0 >= 0 && p0 >= seq_pos_min(s)) { + // Allow X == Y for speculative decoding where seq_rm + re-eval at same position is valid + if (p0 >= 0 && p0 > seq_pos_min(s)) { LLAMA_LOG_ERROR( "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" - " for M-RoPE, it is required that the position satisfies: X < Y\n", + " for M-RoPE, it is required that the position satisfies: X <= Y\n", __func__, s, s, p0, s, seq_pos_min(s)); return false; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 6aa73630c9..ae7d7d6a79 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -777,6 +777,13 @@ float * llama_context::get_logits() { return logits.data; } +float * llama_context::get_mtp_logits() { + if (!mtp_logits_valid || mtp_logits_buf.empty()) { + return nullptr; + } + return mtp_logits_buf.data(); +} + int64_t llama_context::output_resolve_row(int32_t i) const { int64_t j = -1; @@ -1806,6 +1813,24 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // Extract MTP logits if available + if (res->t_logits_mtp != nullptr && n_outputs > 0) { + ggml_backend_t backend_mtp = 
ggml_backend_sched_get_tensor_backend(sched.get(), res->t_logits_mtp); + if (backend_mtp != nullptr) { + const int64_t mtp_n_vocab = res->t_logits_mtp->ne[0]; + const int64_t mtp_n_tokens = res->t_logits_mtp->ne[1]; + + mtp_logits_buf.resize(mtp_n_vocab); + const size_t offset = (mtp_n_tokens - 1) * mtp_n_vocab * sizeof(float); + ggml_backend_tensor_get_async(backend_mtp, res->t_logits_mtp, + mtp_logits_buf.data(), offset, mtp_n_vocab * sizeof(float)); + mtp_logits_valid = true; + this->mtp_n_vocab = mtp_n_vocab; + } + } else { + mtp_logits_valid = false; + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -3089,6 +3114,16 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return res; } +float * llama_get_mtp_logits(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_mtp_logits(); +} + +int64_t llama_get_mtp_n_vocab(llama_context * ctx) { + return ctx->get_mtp_n_vocab(); +} + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); diff --git a/src/llama-context.h b/src/llama-context.h index e0d0085c1c..83481e97dc 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -74,6 +74,8 @@ struct llama_context { float * get_logits(); float * get_logits_ith(int32_t i); + float * get_mtp_logits(); + int64_t get_mtp_n_vocab() const { return mtp_n_vocab; } float * get_embeddings(); float * get_embeddings_ith(int32_t i); @@ -268,6 +270,11 @@ private: // decode output (2-dimensional array: [n_outputs][n_vocab]) buffer_view logits = {nullptr, 0}; + // MTP draft logits — with FastMTP, reduced to top-K tokens (e.g., 32K vs 248K) + std::vector mtp_logits_buf; + bool mtp_logits_valid = false; + int64_t mtp_n_vocab = 0; + // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when 
pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; diff --git a/src/llama-graph.h b/src/llama-graph.h index 4855685ef7..cadcf56104 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -662,6 +662,10 @@ public: ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + // MTP (Multi-Token Prediction) output nodes + ggml_tensor * t_logits_mtp = nullptr; // [n_vocab, n_tokens] draft logits from MTP head + ggml_tensor * t_embd_mtp = nullptr; // [n_embd, n_tokens] hidden state from MTP head + std::map t_sampled_logits; std::map t_candidates; std::map t_sampled; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6e8413f493..933a4d85da 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -10,6 +10,7 @@ #include #include #include + #include // @@ -163,12 +164,55 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); - return false; + // for speculative decoding, search for the best checkpoint to roll back to. + // Prefer exact match at p0-1, but accept the closest position < p0. + // For MTP with 2-token batches, checkpoint may be at p0-2 (before the batch) + // since both tokens are processed atomically. 
+ int32_t best_cell = -1; + llama_pos best_pos = -1; + fprintf(stderr, "[MTP-SEQRM] seq_id=%d, p0=%d, p1=%d, tail_pos=%d, searching for checkpoint at pos<=%d\n", + (int)seq_id, (int)p0, (int)p1, (int)cell.pos, (int)(p0-1)); + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + fprintf(stderr, "[MTP-SEQRM] cell[%d] pos=%d\n", i, (int)cells[i].pos); + // Find the closest checkpoint at or below p0-1 + if (cells[i].pos < p0 && cells[i].pos > best_pos) { + best_pos = cells[i].pos; + best_cell = i; + } + } + } + fflush(stderr); + + if (best_cell >= 0) { + fprintf(stderr, "[MTP-SEQRM] FOUND checkpoint at cell[%d] pos=%d (target was %d) — rolling back\n", + best_cell, (int)best_pos, (int)(p0-1)); + fflush(stderr); + tail_id = best_cell; + } else { + fprintf(stderr, "[MTP-SEQRM] NO checkpoint found — seq_rm FAILED\n"); + fflush(stderr); + return false; + } } // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; + if (p0 == 0) { + tail_id = -1; + } else { + // Search for the best remaining cell after removal + int32_t new_tail = -1; + llama_pos max_pos = -1; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) { + if (cells[i].pos > max_pos) { + max_pos = cells[i].pos; + new_tail = i; + } + } + } + tail_id = new_tail; + } } } } else { @@ -184,6 +228,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (seq_id < 0) { cells[i].seq_id.clear(); } else if (cells[i].has_seq_id(seq_id)) { + if (p0 > 0 && p1 == std::numeric_limits::max()) { + // partial removal: just move the position back + cells[i].pos = p0 - 1; + continue; + } cells[i].seq_id.erase(seq_id); } else { continue; @@ -224,25 +273,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id } if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - auto & tail_src = cells[seq_id_src]; - auto & tail_dst = cells[seq_id_dst]; - if 
(tail_dst.tail >= 0) { + auto & tail_src_meta = cells[seq_id_src]; + auto & tail_dst_meta = cells[seq_id_dst]; + + if (tail_dst_meta.tail >= 0) { // clear destination seq_id if it wasn't empty - auto & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.src = -1; - used -= 1; - } + seq_rm(seq_id_dst, -1, -1); } - if (tail_src.tail >= 0) { - auto & cell_src = cells[tail_src.tail]; - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; + if (tail_src_meta.tail >= 0) { + auto & cell_src = cells[tail_src_meta.tail]; + + // For recurrent models, we must copy the state to a new cell + // Otherwise, both sequences would share the same mutable state + uint32_t next_empty_cell = size; + for (uint32_t i = head; i < head + size; ++i) { + uint32_t idx = i % size; + if (cells[idx].is_empty()) { + next_empty_cell = idx; + break; + } + } + + if (next_empty_cell != size) { + auto & empty_cell = cells[next_empty_cell]; + + // Copy tensors data + copy_cell(tail_src_meta.tail, next_empty_cell); + + empty_cell.pos = cell_src.pos; + empty_cell.src = next_empty_cell; // results in a copy in the graph if needed + empty_cell.seq_id.insert(seq_id_dst); + tail_dst_meta.tail = next_empty_cell; + used += 1; + } else { + LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__); + } } } } @@ -367,6 +433,61 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) { + if (i_src == i_dst || i_src < 0 || i_dst < 0) { + return; + } + + fprintf(stderr, "[MTP-COPYCELL] copy_cell(%d -> %d), n_layer=%d\n", i_src, i_dst, (int)hparams.n_layer); + fflush(stderr); + + // Copy recurrent state via GPU-to-GPU (ggml_backend_tensor_copy). + // Views created with no_alloc=true have buffer=NULL. 
We must set + // the buffer to the parent tensor's buffer for the copy to work. + ggml_init_params params = { + /*.mem_size =*/ size_t(2*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if (r_l[il]) { + ggml_context * ctx = ggml_init(params); + // Tensor is 1D: ne[0] = n_embd_r * size. Each cell = ne[0]/size elements. + // ggml_view_1d takes ELEMENT count, not byte count! + int64_t cell_elements = r_l[il]->ne[0] / size; + size_t cell_bytes = ggml_row_size(r_l[il]->type, cell_elements); + ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], cell_elements, (size_t)i_src * cell_bytes); + ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], cell_elements, (size_t)i_dst * cell_bytes); + src_v->buffer = r_l[il]->buffer; + dst_v->buffer = r_l[il]->buffer; + ggml_backend_tensor_copy(src_v, dst_v); + ggml_free(ctx); + } + if (s_l[il]) { + ggml_context * ctx = ggml_init(params); + int64_t cell_elements = s_l[il]->ne[0] / size; + size_t cell_bytes = ggml_row_size(s_l[il]->type, cell_elements); + ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], cell_elements, (size_t)i_src * cell_bytes); + ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], cell_elements, (size_t)i_dst * cell_bytes); + src_v->buffer = s_l[il]->buffer; + dst_v->buffer = s_l[il]->buffer; + ggml_backend_tensor_copy(src_v, dst_v); + ggml_free(ctx); + } + } +} + +int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const { + int count = 0; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + count++; + } + } + return count; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -453,6 +574,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_seq_tokens = ubatch.n_seq_tokens; const uint32_t n_seqs = ubatch.n_seqs; + fprintf(stderr, "[MTP-FINDSLOT] find_slot: n_seq_tokens=%d, n_seqs=%d, size=%d, 
used=%d, head=%d\n", + (int)n_seq_tokens, (int)n_seqs, (int)size, (int)used, (int)head); + fflush(stderr); + // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it if (head > used + 2*n_seqs) { @@ -551,10 +676,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { if (seq_meta.tail >= 0) { auto & orig_cell = cells[seq_meta.tail]; empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); + empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail + + // Copy state data + copy_cell(seq_meta.tail, next_empty_cell); + + // Keep history of previous states for rollback (up to 8 cells per sequence) + if (get_cell_count(seq_id) < 8 && used < size * 0.9) { + // Do not erase seq_id from orig_cell to keep it as a checkpoint + } else { + // Erase oldest history point for this sequence + int32_t oldest_cell = -1; + llama_pos min_pos = std::numeric_limits::max(); + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) { + min_pos = cells[i].pos; + oldest_cell = i; + } + } + + if (oldest_cell >= 0) { + cells[oldest_cell].seq_id.erase(seq_id); + if (cells[oldest_cell].is_empty()) { + cells[oldest_cell].pos = -1; + cells[oldest_cell].src = -1; + used--; + } + } + } empty_cell.seq_id.insert(seq_id); // will be overwritten - GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id } seq_meta.tail = next_empty_cell; // find next empty cell @@ -566,6 +716,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { if (cell.is_empty()) { break; } } } + } else { + // Sequence owns its cell — no checkpoint for MTP. + // For hybrid models with MTP, checkpointing is too expensive (GPU copy + // of all 33 recurrent layers every step). Instead, we let the recurrent + // state accumulate draft tokens on rejection. 
The 9 attention layers + // handle rollback via KV cache seq_rm, which partially compensates. } if (min > seq_meta.tail) { min = seq_meta.tail; } if (max < seq_meta.tail) { max = seq_meta.tail; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d7391..b6b5d6cfbd 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -65,6 +65,10 @@ public: void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + // cell management + void copy_cell(int32_t i_src, int32_t i_dst); + int get_cell_count(llama_seq_id seq_id) const; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f8caad2889..5360b149cb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2408,16 +2408,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // The total n_layer includes MTP layers appended after main layers. + // Determine the number of main transformer layers for type detection. 
+ const uint32_t n_main_layers = hparams.n_layer - hparams.nextn_predict_layers; + + // Mark recurrent layers (linear attention layers) — main layers only + // MTP layers use full attention, so they are NOT recurrent { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + if (i < n_main_layers) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } else { + // MTP layers use full attention (not recurrent) + hparams.recurrent_layer_arr[i] = false; + } } } - switch (hparams.n_layer) { + switch (n_main_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -7277,39 +7290,67 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t value_dim = head_v_dim * n_v_heads; const int64_t conv_dim = key_dim * 2 + value_dim; + const uint32_t n_main_layers = n_layer - hparams.nextn_predict_layers; + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + const bool is_mtp_layer = (static_cast<uint32_t>(i) >= n_main_layers); - if (!hparams.is_recurrent(i)) { - // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + if (is_mtp_layer) { + // MTP layer: nextn-specific tensors + standard 
attention + standard FFN + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + // MTP layer uses same gated attention as main model (joint QG projection) + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + + // MTP layer uses standard dense FFN + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = 
create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } + // Main transformer layers + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (!hparams.is_recurrent(i)) { + // Full attention layers (joint QG projection + gated attention) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + 
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } } } break; case LLM_ARCH_MIMO2: @@ -8076,6 +8117,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { // Use hybrid-iswa for hybrid models with SWA + // For MTP speculative decoding, we need extra recurrent state + // cells for 
checkpoint/restore. Each sequence needs at least + // 1 active cell + 1 checkpoint cell per MTP draft step. + const uint32_t n_mtp = hparams.nextn_predict_layers; + // For MTP: need room for active cell + checkpoint cells. + // With size=4: active(1) + checkpoint(1) + room(2) ensures + // can_checkpoint (used < size*0.9 = 3.6) can fire even with 3 cells in use. + // No checkpoint overhead — rs_size = 1 per sequence. + // MTP uses single-batch decode without recurrent state rollback. + const uint32_t rs_per_seq = 1; + const uint32_t rs_size = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq); + res = new llama_memory_hybrid_iswa( /* model */ *this, /* attn_type_k */ params.type_k, @@ -8087,13 +8140,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_n_pad */ 1, /* recurrent_type_r */ GGML_TYPE_F32, /* recurrent_type_s */ GGML_TYPE_F32, - /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* recurrent_rs_size */ rs_size, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), /* filter_recr */ std::move(filter_recr)); } else { + const uint32_t rs_per_seq2 = 1; + const uint32_t rs_size2 = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq2); + res = new llama_memory_hybrid( /* model */ *this, /* attn_type_k */ params.type_k, @@ -8105,7 +8161,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* recurrent_kv_size */ rs_size2, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, @@ -8760,6 +8816,10 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } +int32_t llama_model_n_mtp_layers(const llama_model * model) { + 
return model->hparams.nextn_predict_layers; +} + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/src/models/models.h b/src/models/models.h index a86b2b1ebd..bf8af823b7 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -597,7 +597,14 @@ private: ggml_tensor * input, int il); + // Build the MTP (Multi-Token Prediction) head with standard transformer block + void build_mtp_head(llm_graph_input_mem_hybrid * inp, ggml_tensor * inp_pos, int * sections); + const llama_model & model; + + // Unfiltered hidden state from last main layer (before inp_out_ids filter). + // Used by MTP head for attention KV cache population across all batch tokens. + ggml_tensor * mtp_inp_hidden = nullptr; }; // TODO: derive llm_build_delta_net_base instead diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index e0e48d2a4f..ac81da7563 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -23,7 +23,10 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // Only process main transformer layers (skip MTP layers appended at the end) + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -40,35 +43,43 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + // For the last main layer, process BOTH filtered and unfiltered paths: + // - Unfiltered: saved for MTP head (needs all batch tokens for attention KV cache) + // - 
Filtered: used for main model logits (only output tokens) + if (il == n_transformer_layers - 1 && inp_out_ids) { + // First: compute full layer output without filtering (for MTP) + ggml_tensor * full_residual = ggml_add(ctx0, cur, inpSA); + ggml_tensor * full_ffn_res = full_residual; + ggml_tensor * full_post_norm = build_norm(full_residual, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + ggml_tensor * full_ffn = build_layer_ffn(full_post_norm, il); + mtp_inp_hidden = ggml_add(ctx0, full_ffn, full_ffn_res); + mtp_inp_hidden = build_cvec(mtp_inp_hidden, il); + cb(mtp_inp_hidden, "mtp_inp_hidden", il); + + // Second: filter for main model logits + cur = ggml_get_rows(ctx0, mtp_inp_hidden, inp_out_ids); + inpL = cur; + } else { + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + ggml_tensor * ffn_residual = cur; + + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_ffn", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; } - - // Residual connection - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "attn_residual", il); - - // Save the tensor before post-attention norm for residual connection - ggml_tensor * ffn_residual = cur; - - // Post-attention norm - ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); - cb(attn_post_norm, "attn_post_norm", il); - - // Dense FFN layer - without residual connection - cur = build_layer_ffn(attn_post_norm, il); - cb(cur, "ffn_out", il); - - // Residual connection for FFN - add to the tensor from before post_attention_layernorm - cur = ggml_add(ctx0, cur, ffn_residual); - cb(cur, "post_ffn", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // Input for 
next layer - inpL = cur; } cur = inpL; @@ -85,6 +96,11 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa res->t_logits = cur; ggml_build_forward_expand(gf, cur); + + // Build MTP head if nextn_predict_layers > 0 + if (hparams.nextn_predict_layers > 0) { + build_mtp_head(inp, inp_pos, sections); + } } std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz( @@ -382,3 +398,145 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) return cur; } + +void llm_build_qwen35::build_mtp_head( + llm_graph_input_mem_hybrid * inp, + ggml_tensor * inp_pos, + int * sections) { + // MTP (Multi-Token Prediction) head for dense Qwen 3.5 + // + // The MTP module takes the hidden state from the last main transformer layer + // and uses the model's built-in MTP head to produce draft logits. + // + // MTP forward pass: + // 1. sampled_token = argmax(main_logits) + // 2. emb = embed_tokens(sampled_token) + // 3. h_norm = RMSNorm(hidden_state, hnorm) + // 4. e_norm = RMSNorm(emb, enorm) + // 5. combined = eh_proj(concat(e_norm, h_norm)) + // 6. Standard self-attention (Q/K/V with Q/K norms + RoPE) + // 7. Standard FFN (gate_proj + up_proj → SiLU → down_proj) + // 8. logits = lm_head(RMSNorm(output, mtp_norm)) + + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + const int64_t n_embd_head = hparams.n_embd_head_v(); + + // Use unfiltered hidden state for MTP (needs all batch tokens for attention KV cache) + ggml_tensor * hidden_state = mtp_inp_hidden ? mtp_inp_hidden : res->t_embd; + GGML_ASSERT(hidden_state != nullptr); + + // Get logits for greedy token selection. + // If no filtering occurred (generation), reuse main logits to avoid expensive lm_head recomputation. + // If filtering occurred (prompt processing), recompute from unfiltered hidden state. 
+ ggml_tensor * greedy_logits; + if (!mtp_inp_hidden || mtp_inp_hidden == res->t_embd) { + // No filtering — main logits already cover all tokens + greedy_logits = res->t_logits; + } else { + // Filtered — recompute logits from unfiltered hidden state + ggml_tensor * full_normed = build_norm(hidden_state, model.output_norm, nullptr, LLM_NORM_RMS, -1); + greedy_logits = build_lora_mm(model.output, full_normed); + } + + ggml_tensor * greedy_tokens = ggml_argmax(ctx0, greedy_logits); + cb(greedy_tokens, "mtp_greedy_tokens", -1); + + ggml_tensor * mtp_hidden = hidden_state; + + for (uint32_t k = 0; k < hparams.nextn_predict_layers; ++k) { + const int il = n_transformer_layers + k; + const auto & layer = model.layers[il]; + + if (layer.nextn.eh_proj == nullptr) { + continue; + } + + // Step 1: Get token embedding (shared with main model) + ggml_tensor * tok_embd = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + ggml_tensor * emb = ggml_get_rows(ctx0, tok_embd, greedy_tokens); + cb(emb, "mtp_token_embd", il); + + // Step 2: Normalize hidden state and embedding + ggml_tensor * h_norm = build_norm(mtp_hidden, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(emb, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + // Step 3: Concatenate and project + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, 0); // [2*n_embd, n_tokens] + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_projected", il); + + // Step 4: Full self-attention for the MTP head (same architecture as main model attention layers) + // The MTP layer has its own KV cache (allocated because is_recurrent(il) = false). + // We use the unfiltered hidden state (mtp_inp_hidden) so token count matches inp_pos. 
+ { + ggml_tensor * attn_residual = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + + cur = ggml_add(ctx0, cur, attn_residual); + } + + // Step 5: Post-attention norm + FFN + { + ggml_tensor * ffn_residual = cur; + + ggml_tensor * attn_post_norm = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "mtp_attn_post_norm", il); + + // Standard dense FFN (same as main model FFN) + cur = build_ffn(attn_post_norm, + layer.ffn_up, NULL, layer.ffn_up_s, + layer.ffn_gate, NULL, layer.ffn_gate_s, + layer.ffn_down, NULL, layer.ffn_down_s, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + } + + mtp_hidden = cur; + + // Step 6: Final norm + LM head for draft logits + ggml_tensor * mtp_normed; + if (layer.nextn.shared_head_norm != nullptr) { + mtp_normed = build_norm(mtp_hidden, layer.nextn.shared_head_norm, nullptr, LLM_NORM_RMS, il); + } else { + // Use main model's output norm + mtp_normed = build_norm(mtp_hidden, model.output_norm, nullptr, LLM_NORM_RMS, il); + } + cb(mtp_normed, "mtp_head_norm", il); + + ggml_tensor * lm_head = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + + // FastMTP: vocabulary trimming — only compute logits for top-K tokens + // instead of full 248K vocabulary. Most tokenizers order by frequency, + // so tokens 0..K-1 cover ~95%+ of generated code tokens. + // This reduces the lm_head matmul from [4096,248K] to [4096,32K] (~8x faster). 
+ const int64_t mtp_vocab_size = std::min(lm_head->ne[1], (int64_t)32768); + ggml_tensor * lm_head_reduced = ggml_view_2d(ctx0, lm_head, + lm_head->ne[0], mtp_vocab_size, lm_head->nb[1], 0); + ggml_tensor * mtp_logits = build_lora_mm(lm_head_reduced, mtp_normed); + cb(mtp_logits, "mtp_logits", il); + + // Store MTP outputs in graph result + res->t_embd_mtp = mtp_hidden; + res->t_logits_mtp = mtp_logits; + + // For recursive MTP (multiple layers), feed greedy tokens forward + if (k + 1 < hparams.nextn_predict_layers) { + greedy_tokens = ggml_argmax(ctx0, mtp_logits); + cb(greedy_tokens, "mtp_greedy_next", il); + } + + ggml_build_forward_expand(gf, mtp_logits); + } +} diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9de554e900..1c0151252d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -149,6 +149,15 @@ struct server_slot { llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; + // Inline MTP (Multi-Token Prediction) state. + // Instead of using the speculative framework (which has M-RoPE and SSM + // rollback issues), we propose one draft token from MTP logits and verify + // it in the next decode step. No seq_rm or rollback needed. 
+ llama_token mtp_draft_token = -1; // proposed draft token (-1 = none) + int mtp_i_batch = -1; // batch index of the draft token + bool mtp_pending = false; // true when draft is in the batch awaiting verification + bool mtp_cooldown = false; // skip MTP proposal for one iteration after draft processing + // stats size_t n_sent_text = 0; // number of sent text character @@ -179,6 +188,10 @@ struct server_slot { drafted.clear(); i_batch_dft.clear(); + mtp_draft_token = -1; + mtp_i_batch = -1; + mtp_pending = false; + mtp_cooldown = false; generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -753,11 +766,24 @@ private: slots.clear(); - const bool can_spec = common_speculative_is_compat(ctx); + bool can_spec = common_speculative_is_compat(ctx); if (!can_spec) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } + // Auto-detect MTP: if model has MTP layers and no speculative type + // is explicitly set, auto-enable MTP speculative decoding. 
+ if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) { + const int32_t n_mtp = llama_model_n_mtp_layers(llama_get_model(ctx)); + if (n_mtp > 0 && can_spec) { + SRV_INF("model has %d MTP layer(s) — auto-enabling MTP speculative decoding\n", n_mtp); + params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + params_base.speculative.n_max = 1; // MTP-1: one draft token per step + } else if (n_mtp > 0) { + SRV_INF("model has %d MTP layer(s) but speculative context not compatible\n", n_mtp); + } + } + // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; @@ -2066,42 +2092,34 @@ private: } // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch const int n_draft_max = slot.get_n_draft_max(); + const bool is_hybrid = llama_model_is_hybrid(model); + if (n_draft_max > 0) { + // Standard speculative decoding for non-hybrid models if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded GGML_ABORT("not supported by multimodal"); } const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - const auto & params_spec = slot.task->params.speculative; llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); if (draft.size() > (size_t) n_draft_max) { - SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max); draft.resize(n_draft_max); } - // add the sampled token to the batch slot.i_batch_dft.push_back(batch.n_tokens); common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); slot.prompt.tokens.push_back(slot.sampled); if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), 
slot.task->params.speculative.n_min); - // fallback to normal decoding slot.i_batch = slot.i_batch_dft[0]; slot.drafted.clear(); slot.i_batch_dft.clear(); } else { - // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); - - // add all drafted tokens to the batch for (size_t i = 0; i < draft.size(); i++) { slot.i_batch_dft.push_back(batch.n_tokens); common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); @@ -2114,7 +2132,6 @@ private: slot.i_batch = batch.n_tokens; common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", @@ -2821,7 +2838,7 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2833,13 +2850,135 @@ private: const int tok_idx = slot.i_batch - i; + // --- Two-phase MTP: verify draft after decoding sampled token --- + // The sampled token was decoded alone (1-token batch). + // Now verify the draft against main model logits. + // If accepted: decode draft in a second pass (correct checkpoint position). + // If rejected: skip draft decode (no seq_rm needed — state is clean). 
+ if (slot.mtp_pending) { + // Sample from main model at the sampled token's position + llama_token verified = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); + common_sampler_accept(slot.smpl.get(), verified, true); + + slot.n_draft_total += 1; + + const int64_t t_current = ggml_time_us(); + if (slot.n_decoded == 0) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + if (verified == slot.mtp_draft_token) { + // ACCEPTED — decode the draft token in a second pass. + // At this point the recurrent state checkpoint is at the + // correct position (after sampled, before draft). + slot.n_draft_accepted += 1; + slot.n_decoded += 1; + + // Output the verified/sampled token + completion_token_output result_sampled; + result_sampled.tok = verified; + result_sampled.text_to_send = common_token_to_piece(ctx, result_sampled.tok, accept_special_token(slot, result_sampled.tok)); + result_sampled.prob = 1.0f; + + bool should_stop = !process_token(result_sampled, slot); + if (should_stop) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + slot.mtp_pending = false; + slot.mtp_draft_token = -1; + continue; + } + + // Build a 1-token batch for the draft token and decode it + llama_batch draft_batch = llama_batch_init(1, 0, 1); + common_batch_clear(draft_batch); + common_batch_add(draft_batch, slot.mtp_draft_token, slot.prompt.tokens.pos_next(), { slot.id }, true); + slot.prompt.tokens.push_back(slot.mtp_draft_token); + + fprintf(stderr, "[MTP-2PHASE] draft ACCEPTED (verified=%d), decoding draft at pos %d\n", + (int)verified, (int)(slot.prompt.tokens.pos_next() - 1)); + fflush(stderr); + + int ret2 = llama_decode(ctx, draft_batch); + llama_batch_free(draft_batch); + + if (ret2 != 0) { + fprintf(stderr, "[MTP-2PHASE] ERROR: draft decode failed with ret=%d\n", ret2); + fflush(stderr); + // Fall through — 
state is still valid at pre-draft position + } else { + // Sample bonus token from draft's logits (the NEXT-NEXT token) + llama_token bonus = common_sampler_sample(slot.smpl.get(), ctx, 0); + common_sampler_accept(slot.smpl.get(), bonus, true); + slot.n_decoded += 1; + slot.sampled = bonus; + + // Output the bonus token (draft/verified was already output above) + completion_token_output result_bonus; + result_bonus.tok = bonus; + result_bonus.text_to_send = common_token_to_piece(ctx, result_bonus.tok, accept_special_token(slot, result_bonus.tok)); + result_bonus.prob = 1.0f; + + if (!process_token(result_bonus, slot)) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + slot.mtp_pending = false; + slot.mtp_draft_token = -1; + continue; + } + + // Inform speculative state about the acceptance + common_speculative_accept(slot.spec, 1); + } + } else { + // REJECTED — draft was never decoded, state is clean. + // No seq_rm needed! + fprintf(stderr, "[MTP-2PHASE] draft REJECTED (verified=%d, draft=%d)\n", + (int)verified, (int)slot.mtp_draft_token); + fflush(stderr); + + slot.sampled = verified; + slot.n_decoded += 1; + + // Output the verified token + completion_token_output result_verified; + result_verified.tok = verified; + result_verified.text_to_send = common_token_to_piece(ctx, result_verified.tok, accept_special_token(slot, result_verified.tok)); + result_verified.prob = 1.0f; + + if (!process_token(result_verified, slot)) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + slot.mtp_pending = false; + slot.mtp_draft_token = -1; + continue; + } + } + + slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3; + slot.mtp_pending = false; + slot.mtp_draft_token = -1; + slot.mtp_i_batch = -1; + slot.i_batch = -1; + + continue; // done with this slot for this decode step + } + + // --- Normal sampling (no pending MTP draft) --- 
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); slot.i_batch = -1; common_sampler_accept(slot.smpl.get(), id, true); - // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement const int64_t t_current = ggml_time_us(); slot.n_decoded += 1; @@ -2855,14 +2994,13 @@ private: completion_token_output result; result.tok = id; result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + result.prob = 1.0f; if (slot.task->params.sampling.n_probs > 0) { populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { - // release slot because of stop condition slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); @@ -2904,7 +3042,15 @@ private: slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); slot.sampled = ids.back(); // last accepted token - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + // Remove rejected draft tokens from KV cache. + // For hybrid SSM/DeltaNet, seq_rm may fail. In that case, + // just log and continue — the recurrent state has the draft + // token baked in, but the checkpoint mechanism in + // llama-memory-recurrent.cpp should handle rollback internally + // during the next find_slot call. + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1)) { + SLT_WRN(slot, "seq_rm failed at pos %d\n", (int)slot.prompt.n_tokens()); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result;