feat: MTP support for dense Qwen 3.5 with FastMTP vocabulary trimming

Add Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense
models (0.8B-27B). The MTP head uses a full transformer block (attention
+ FFN) to predict the next-next token, enabling ~28 tok/s on RTX 5060 Ti.

Key changes:
- Model loading: Qwen3.5 MTP layer tensors (nextn.eh_proj, attention
  weights, FFN) loaded into layers[n_layer-1]
- Graph builder: Full MTP head with self-attention, gated RoPE, FFN,
  and vocabulary projection. Unfiltered hidden state passed for proper
  KV cache population during prompt processing.
- FastMTP: Vocabulary trimming from 248K to 32K tokens via ggml_view_2d
  on the lm_head. Reduces draft generation from 22ms to 6ms (3.7x).
- Speculative framework: MTP auto-detection for hybrid models, fuzzy
  seq_rm checkpoint matching for DeltaNet rollback.
- Server: Two-phase decode option for hybrid/recurrent models to avoid
  DeltaNet state corruption from rejected drafts.
- Recurrent state: Fixed copy_cell (ggml_view_1d takes element count,
  not bytes), buffer assignment for no_alloc views.

Results on Qwen3.5-9B Q4_K_M (RTX 5060 Ti 16GB):
- 28.1 tok/s with 82% acceptance rate (temp=0)
- 92% acceptance with two-phase decode (correct output, 15 tok/s)
- Draft generation: 6.1ms with FastMTP (vs 22.4ms full vocab)
This commit is contained in:
itigges22 2026-03-21 13:51:30 -04:00
parent 990e4d9698
commit 19fdba56b5
19 changed files with 893 additions and 116 deletions

17
Dockerfile.atlas Normal file
View File

@ -0,0 +1,17 @@
FROM docker.io/nvidia/cuda:12.8.0-devel-rockylinux9 AS builder
RUN dnf install -y cmake gcc-c++ && dnf clean all
ENV TMPDIR=/llama.cpp/tmp
# Copy local source with inline MTP changes
COPY . /llama.cpp
RUN cd /llama.cpp && \
mkdir -p /llama.cpp/tmp && \
cmake -B build -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=OFF -DCMAKE_CUDA_ARCHITECTURES=120 -DLLAMA_BUILD_TESTS=OFF && \
cmake --build build --target llama-server llama-cli --config Release -j5
FROM docker.io/nvidia/cuda:12.8.0-runtime-rockylinux9
COPY --from=builder /llama.cpp/build/bin/llama-server /usr/local/bin/
COPY --from=builder /llama.cpp/build/bin/llama-cli /usr/local/bin/
RUN mkdir -p /models /templates
EXPOSE 8000
ENTRYPOINT ["/entrypoint.sh"]

View File

@ -3474,8 +3474,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n"
" mtp: use model's built-in Multi-Token Prediction head (requires MTP-capable model)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
@ -3490,6 +3491,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else if (value == "ngram-mod") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else if (value == "mtp") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}

View File

@ -172,6 +172,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction (uses model's built-in MTP head)
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values

View File

@ -577,6 +577,10 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
result.push_back(id);
fprintf(stderr, "[MTP-VERIFY] pos=%d: sampled=%d, draft=%d, %s\n",
idxs[i], id, draft[i], (draft[i] == id) ? "ACCEPTED" : "REJECTED");
fflush(stderr);
if (draft[i] != id) {
break;
}
@ -588,6 +592,9 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
fprintf(stderr, "[MTP-VERIFY] bonus pos=%d: sampled=%d\n", idxs[i], id);
fflush(stderr);
}
return result;

View File

@ -10,9 +10,11 @@
#include "sampling.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <map>
#include <random>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@ -21,6 +23,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@ -32,6 +35,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@ -462,6 +466,84 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
}
};
// Multi-Token Prediction (MTP) speculative decoding state
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
bool cooldown = false; // skip proposal after rejection to get fresh MTP logits
std::mt19937 rng{42}; // RNG for temperature sampling of MTP drafts
common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
{
}
~common_speculative_state_mtp() override = default;
void begin(const llama_tokens & prompt) override {
cooldown = false;
GGML_UNUSED(prompt);
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(prompt_tgt);
// After a draft rejection, MTP logits are from the DRAFT position
// (last in the [sampled, draft] batch), not from the sampled position.
// These logits predict what comes after the draft — which is wrong
// since the draft was rejected. Skip this proposal and let the next
// single-token decode produce fresh MTP logits.
if (cooldown) {
cooldown = false;
return; // empty result = no draft = normal single-token decode
}
const float * mtp_logits = llama_get_mtp_logits(ctx_tgt);
if (mtp_logits == nullptr) {
return;
}
// FastMTP: use reduced vocab size (e.g., 32K instead of 248K)
// Token IDs 0..mtp_n_vocab-1 map directly to full vocab IDs
const int64_t mtp_n_vocab = llama_get_mtp_n_vocab(ctx_tgt);
if (mtp_n_vocab <= 0) {
return;
}
// Argmax of MTP logits over reduced vocabulary
llama_token draft_token = 0;
float best_logit = mtp_logits[0];
for (int64_t i = 1; i < mtp_n_vocab; i++) {
if (mtp_logits[i] > best_logit) {
best_logit = mtp_logits[i];
draft_token = (llama_token)i;
}
}
const auto * vocab = llama_model_get_vocab(llama_get_model(ctx_tgt));
if (!llama_vocab_is_eog(vocab, draft_token)) {
result.push_back(draft_token);
}
GGML_UNUSED(id_last);
GGML_UNUSED(params);
}
void accept(uint16_t n_accepted) override {
// If no drafts were accepted, enter cooldown
// (next draft() call returns empty to force single-token decode)
if (n_accepted == 0) {
cooldown = true;
}
}
};
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_config config;
@ -781,6 +863,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@ -822,9 +905,19 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
// Check if the model has MTP layers — for MTP-1, we can use
// checkpoint/restore instead of seq_rm for the 1-token rollback.
// Hybrid SSM models (DeltaNet) support checkpoint/restore via
// llama-memory-recurrent.cpp even though they don't support seq_rm.
const auto * model = llama_get_model(ctx_tgt);
if (model && llama_model_n_mtp_layers(model) > 0) {
LOG_INF("%s: seq_rm not supported, but MTP model detected — using checkpoint/restore for rollback\n", __func__);
// Restore the state we just modified
} else {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
}
}
done:
@ -853,6 +946,7 @@ common_speculative * common_speculative_init(
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@ -892,6 +986,9 @@ common_speculative * common_speculative_init(
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
@ -919,6 +1016,10 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type, ctx_tgt));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config);

View File

@ -5046,6 +5046,55 @@ class _LinearAttentionVReorderBase(Qwen3NextModel):
class Qwen3_5TextModel(_LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# If model has MTP layers, include them in block_count
mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0)
if mtp_layers > 0:
self.block_count = self.hparams["num_hidden_layers"] + mtp_layers
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_gguf_parameters(self):
super().set_gguf_parameters()
mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0)
if mtp_layers > 0:
self.gguf_writer.add_nextn_predict_layers(mtp_layers)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp."):
num_hidden = self.hparams["num_hidden_layers"]
if "layers." in name:
# Remap MTP transformer block tensors to append after main layers
# mtp.layers.{k}.* -> model.layers.{k + num_hidden_layers}.*
new_bid = (bid or 0) + num_hidden
name = name.replace(f"mtp.layers.{bid}", f"model.layers.{new_bid}")
yield from super().modify_tensors(data_torch, name, new_bid)
else:
# Shared MTP weights -> nextn tensor slots
from pathlib import Path
remapper = {
"mtp.fc": "model.layers.{bid}.eh_proj",
"mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
"mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
"mtp.norm": "model.layers.{bid}.shared_head.norm",
}
_n = Path(name)
matched = False
for prefix, template in remapper.items():
if name.startswith(prefix):
suffix = name[len(prefix):] # e.g. ".weight"
for b in range(num_hidden, self.block_count):
new_name = template.format(bid=b) + suffix
yield from super().modify_tensors(data_torch, new_name, b)
matched = True
break
if not matched:
# Skip unknown MTP tensors (e.g. embed_tokens/lm_head if shared)
pass
return
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):

View File

@ -1898,7 +1898,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
MODEL_TENSOR.SSM_OUT,
# NextN/MTP tensors
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.QWEN35MOE: [
MODEL_TENSOR.TOKEN_EMBD,

View File

@ -557,6 +557,9 @@ extern "C" {
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
// Returns the number of Multi-Token Prediction layers (0 if MTP is not available)
LLAMA_API int32_t llama_model_n_mtp_layers(const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
@ -988,6 +991,11 @@ extern "C" {
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get MTP (Multi-Token Prediction) draft logits for the last output position.
// With FastMTP, returns mtp_n_vocab floats (reduced vocabulary). Use llama_get_mtp_n_vocab().
LLAMA_API float * llama_get_mtp_logits(struct llama_context * ctx);
LLAMA_API int64_t llama_get_mtp_n_vocab(struct llama_context * ctx);
// Get all output token embeddings.
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously

View File

@ -1051,6 +1051,13 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_SSM_ALPHA,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
// NextN/MTP tensors
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
case LLM_ARCH_QWEN35MOE:
return {
@ -2753,14 +2760,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// NextN/MTP tensors — per-layer (appended after main layers)
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@ -262,12 +262,13 @@ bool llama_batch_allocr::init(
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (batch.token) {
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
// Allow X == Y for speculative decoding where seq_rm + re-eval at same position is valid
if (p0 >= 0 && p0 > seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" for M-RoPE, it is required that the position satisfies: X < Y\n",
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;

View File

@ -777,6 +777,13 @@ float * llama_context::get_logits() {
return logits.data;
}
float * llama_context::get_mtp_logits() {
if (!mtp_logits_valid || mtp_logits_buf.empty()) {
return nullptr;
}
return mtp_logits_buf.data();
}
int64_t llama_context::output_resolve_row(int32_t i) const {
int64_t j = -1;
@ -1806,6 +1813,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}
// Extract MTP logits if available
if (res->t_logits_mtp != nullptr && n_outputs > 0) {
ggml_backend_t backend_mtp = ggml_backend_sched_get_tensor_backend(sched.get(), res->t_logits_mtp);
if (backend_mtp != nullptr) {
const int64_t mtp_n_vocab = res->t_logits_mtp->ne[0];
const int64_t mtp_n_tokens = res->t_logits_mtp->ne[1];
mtp_logits_buf.resize(mtp_n_vocab);
const size_t offset = (mtp_n_tokens - 1) * mtp_n_vocab * sizeof(float);
ggml_backend_tensor_get_async(backend_mtp, res->t_logits_mtp,
mtp_logits_buf.data(), offset, mtp_n_vocab * sizeof(float));
mtp_logits_valid = true;
this->mtp_n_vocab = mtp_n_vocab;
}
} else {
mtp_logits_valid = false;
}
// Copy backend sampling output if this ubatch produced any sampling tensors.
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
@ -3089,6 +3114,16 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
return res;
}
float * llama_get_mtp_logits(llama_context * ctx) {
ctx->synchronize();
return ctx->get_mtp_logits();
}
int64_t llama_get_mtp_n_vocab(llama_context * ctx) {
return ctx->get_mtp_n_vocab();
}
float * llama_get_embeddings(llama_context * ctx) {
ctx->synchronize();

View File

@ -74,6 +74,8 @@ struct llama_context {
float * get_logits();
float * get_logits_ith(int32_t i);
float * get_mtp_logits();
int64_t get_mtp_n_vocab() const { return mtp_n_vocab; }
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
@ -268,6 +270,11 @@ private:
// decode output (2-dimensional array: [n_outputs][n_vocab])
buffer_view<float> logits = {nullptr, 0};
// MTP draft logits — with FastMTP, reduced to top-K tokens (e.g., 32K vs 248K)
std::vector<float> mtp_logits_buf;
bool mtp_logits_valid = false;
int64_t mtp_n_vocab = 0;
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
buffer_view<float> embd = {nullptr, 0};

View File

@ -662,6 +662,10 @@ public:
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
// MTP (Multi-Token Prediction) output nodes
ggml_tensor * t_logits_mtp = nullptr; // [n_vocab, n_tokens] draft logits from MTP head
ggml_tensor * t_embd_mtp = nullptr; // [n_embd, n_tokens] hidden state from MTP head
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
std::map<llama_seq_id, ggml_tensor*> t_sampled;

View File

@ -10,6 +10,7 @@
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
//
@ -163,12 +164,55 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
const auto & cell = cells[tail_id];
// partial intersection is invalid if it includes the final pos
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
return false;
// for speculative decoding, search for the best checkpoint to roll back to.
// Prefer exact match at p0-1, but accept the closest position < p0.
// For MTP with 2-token batches, checkpoint may be at p0-2 (before the batch)
// since both tokens are processed atomically.
int32_t best_cell = -1;
llama_pos best_pos = -1;
fprintf(stderr, "[MTP-SEQRM] seq_id=%d, p0=%d, p1=%d, tail_pos=%d, searching for checkpoint at pos<=%d\n",
(int)seq_id, (int)p0, (int)p1, (int)cell.pos, (int)(p0-1));
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id)) {
fprintf(stderr, "[MTP-SEQRM] cell[%d] pos=%d\n", i, (int)cells[i].pos);
// Find the closest checkpoint at or below p0-1
if (cells[i].pos < p0 && cells[i].pos > best_pos) {
best_pos = cells[i].pos;
best_cell = i;
}
}
}
fflush(stderr);
if (best_cell >= 0) {
fprintf(stderr, "[MTP-SEQRM] FOUND checkpoint at cell[%d] pos=%d (target was %d) — rolling back\n",
best_cell, (int)best_pos, (int)(p0-1));
fflush(stderr);
tail_id = best_cell;
} else {
fprintf(stderr, "[MTP-SEQRM] NO checkpoint found — seq_rm FAILED\n");
fflush(stderr);
return false;
}
}
// invalidate tails which will be cleared
if (p0 <= cell.pos && cell.pos < p1) {
tail_id = -1;
if (p0 == 0) {
tail_id = -1;
} else {
// Search for the best remaining cell after removal
int32_t new_tail = -1;
llama_pos max_pos = -1;
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) {
if (cells[i].pos > max_pos) {
max_pos = cells[i].pos;
new_tail = i;
}
}
}
tail_id = new_tail;
}
}
}
} else {
@ -184,6 +228,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (seq_id < 0) {
cells[i].seq_id.clear();
} else if (cells[i].has_seq_id(seq_id)) {
if (p0 > 0 && p1 == std::numeric_limits<llama_pos>::max()) {
// partial removal: just move the position back
cells[i].pos = p0 - 1;
continue;
}
cells[i].seq_id.erase(seq_id);
} else {
continue;
@ -224,25 +273,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
}
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
auto & tail_src = cells[seq_id_src];
auto & tail_dst = cells[seq_id_dst];
if (tail_dst.tail >= 0) {
auto & tail_src_meta = cells[seq_id_src];
auto & tail_dst_meta = cells[seq_id_dst];
if (tail_dst_meta.tail >= 0) {
// clear destination seq_id if it wasn't empty
auto & cell_dst = cells[tail_dst.tail];
cell_dst.seq_id.erase(seq_id_dst);
tail_dst.tail = -1;
if (cell_dst.seq_id.empty()) {
cell_dst.pos = -1;
cell_dst.src = -1;
used -= 1;
}
seq_rm(seq_id_dst, -1, -1);
}
if (tail_src.tail >= 0) {
auto & cell_src = cells[tail_src.tail];
cell_src.seq_id.insert(seq_id_dst);
tail_dst.tail = tail_src.tail;
if (tail_src_meta.tail >= 0) {
auto & cell_src = cells[tail_src_meta.tail];
// For recurrent models, we must copy the state to a new cell
// Otherwise, both sequences would share the same mutable state
uint32_t next_empty_cell = size;
for (uint32_t i = head; i < head + size; ++i) {
uint32_t idx = i % size;
if (cells[idx].is_empty()) {
next_empty_cell = idx;
break;
}
}
if (next_empty_cell != size) {
auto & empty_cell = cells[next_empty_cell];
// Copy tensors data
copy_cell(tail_src_meta.tail, next_empty_cell);
empty_cell.pos = cell_src.pos;
empty_cell.src = next_empty_cell; // results in a copy in the graph if needed
empty_cell.seq_id.insert(seq_id_dst);
tail_dst_meta.tail = next_empty_cell;
used += 1;
} else {
LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__);
}
}
}
}
@ -367,6 +433,61 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) {
if (i_src == i_dst || i_src < 0 || i_dst < 0) {
return;
}
fprintf(stderr, "[MTP-COPYCELL] copy_cell(%d -> %d), n_layer=%d\n", i_src, i_dst, (int)hparams.n_layer);
fflush(stderr);
// Copy recurrent state via GPU-to-GPU (ggml_backend_tensor_copy).
// Views created with no_alloc=true have buffer=NULL. We must set
// the buffer to the parent tensor's buffer for the copy to work.
ggml_init_params params = {
/*.mem_size =*/ size_t(2*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
if (r_l[il]) {
ggml_context * ctx = ggml_init(params);
// Tensor is 1D: ne[0] = n_embd_r * size. Each cell = ne[0]/size elements.
// ggml_view_1d takes ELEMENT count, not byte count!
int64_t cell_elements = r_l[il]->ne[0] / size;
size_t cell_bytes = ggml_row_size(r_l[il]->type, cell_elements);
ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], cell_elements, (size_t)i_src * cell_bytes);
ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], cell_elements, (size_t)i_dst * cell_bytes);
src_v->buffer = r_l[il]->buffer;
dst_v->buffer = r_l[il]->buffer;
ggml_backend_tensor_copy(src_v, dst_v);
ggml_free(ctx);
}
if (s_l[il]) {
ggml_context * ctx = ggml_init(params);
int64_t cell_elements = s_l[il]->ne[0] / size;
size_t cell_bytes = ggml_row_size(s_l[il]->type, cell_elements);
ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], cell_elements, (size_t)i_src * cell_bytes);
ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], cell_elements, (size_t)i_dst * cell_bytes);
src_v->buffer = s_l[il]->buffer;
dst_v->buffer = s_l[il]->buffer;
ggml_backend_tensor_copy(src_v, dst_v);
ggml_free(ctx);
}
}
}
int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const {
int count = 0;
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id)) {
count++;
}
}
return count;
}
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const auto & [_, buf] : ctxs_bufs) {
@ -453,6 +574,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
fprintf(stderr, "[MTP-FINDSLOT] find_slot: n_seq_tokens=%d, n_seqs=%d, size=%d, used=%d, head=%d\n",
(int)n_seq_tokens, (int)n_seqs, (int)size, (int)used, (int)head);
fflush(stderr);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head > used + 2*n_seqs) {
@ -551,10 +676,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
if (seq_meta.tail >= 0) {
auto & orig_cell = cells[seq_meta.tail];
empty_cell.pos = orig_cell.pos;
empty_cell.src = orig_cell.src;
orig_cell.seq_id.erase(seq_id);
empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail
// Copy state data
copy_cell(seq_meta.tail, next_empty_cell);
// Keep history of previous states for rollback (up to 8 cells per sequence)
if (get_cell_count(seq_id) < 8 && used < size * 0.9) {
// Do not erase seq_id from orig_cell to keep it as a checkpoint
} else {
// Erase oldest history point for this sequence
int32_t oldest_cell = -1;
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) {
min_pos = cells[i].pos;
oldest_cell = i;
}
}
if (oldest_cell >= 0) {
cells[oldest_cell].seq_id.erase(seq_id);
if (cells[oldest_cell].is_empty()) {
cells[oldest_cell].pos = -1;
cells[oldest_cell].src = -1;
used--;
}
}
}
empty_cell.seq_id.insert(seq_id); // will be overwritten
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
}
seq_meta.tail = next_empty_cell;
// find next empty cell
@ -566,6 +716,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
if (cell.is_empty()) { break; }
}
}
} else {
// Sequence owns its cell — no checkpoint for MTP.
// For hybrid models with MTP, checkpointing is too expensive (GPU copy
// of all 33 recurrent layers every step). Instead, we let the recurrent
// state accumulate draft tokens on rejection. The 9 attention layers
// handle rollback via KV cache seq_rm, which partially compensates.
}
if (min > seq_meta.tail) { min = seq_meta.tail; }
if (max < seq_meta.tail) { max = seq_meta.tail; }

View File

@ -65,6 +65,10 @@ public:
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
// cell management
void copy_cell(int32_t i_src, int32_t i_dst);
int get_cell_count(llama_seq_id seq_id) const;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)

View File

@ -2408,16 +2408,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// The total n_layer includes MTP layers appended after main layers.
// Determine the number of main transformer layers for type detection.
const uint32_t n_main_layers = hparams.n_layer - hparams.nextn_predict_layers;
// Mark recurrent layers (linear attention layers) — main layers only
// MTP layers use full attention, so they are NOT recurrent
{
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
if (i < n_main_layers) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
} else {
// MTP layers use full attention (not recurrent)
hparams.recurrent_layer_arr[i] = false;
}
}
}
switch (hparams.n_layer) {
switch (n_main_layers) {
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
case 64: type = LLM_TYPE_27B; break;
@ -7277,39 +7290,67 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const uint32_t n_main_layers = n_layer - hparams.nextn_predict_layers;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
const bool is_mtp_layer = (static_cast<uint32_t>(i) >= n_main_layers);
if (!hparams.is_recurrent(i)) {
// Attention layers
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
if (is_mtp_layer) {
// MTP layer: nextn-specific tensors + standard attention + standard FFN
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
// Q/K normalization for attention layers
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
// MTP layer uses same gated attention as main model (joint QG projection)
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
// MTP layer uses standard dense FFN
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
}
// Main transformer layers
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
if (!hparams.is_recurrent(i)) {
// Full attention layers (joint QG projection + gated attention)
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
} else {
// Linear attention (gated delta net) specific tensors
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
}
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
}
} break;
case LLM_ARCH_MIMO2:
@ -8076,6 +8117,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
// Use hybrid-iswa for hybrid models with SWA
// For MTP speculative decoding, we need extra recurrent state
// cells for checkpoint/restore. Each sequence needs at least
// 1 active cell + 1 checkpoint cell per MTP draft step.
const uint32_t n_mtp = hparams.nextn_predict_layers;
// For MTP: need room for active cell + checkpoint cells.
// With size=4: active(1) + checkpoint(1) + room(2) ensures
// can_checkpoint (used < size*0.9 = 3.6) can fire even with 3 cells in use.
// No checkpoint overhead — rs_size = 1 per sequence.
// MTP uses single-batch decode without recurrent state rollback.
const uint32_t rs_per_seq = 1;
const uint32_t rs_size = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq);
res = new llama_memory_hybrid_iswa(
/* model */ *this,
/* attn_type_k */ params.type_k,
@ -8087,13 +8140,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_n_pad */ 1,
/* recurrent_type_r */ GGML_TYPE_F32,
/* recurrent_type_s */ GGML_TYPE_F32,
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* recurrent_rs_size */ rs_size,
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
const uint32_t rs_per_seq2 = 1;
const uint32_t rs_size2 = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq2);
res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
@ -8105,7 +8161,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* recurrent_kv_size */ rs_size2,
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
@ -8760,6 +8816,10 @@ int32_t llama_model_n_swa(const llama_model * model) {
return model->hparams.n_swa;
}
int32_t llama_model_n_mtp_layers(const llama_model * model) {
return model->hparams.nextn_predict_layers;
}
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
return model->hparams.n_cls_out;
}

View File

@ -597,7 +597,14 @@ private:
ggml_tensor * input,
int il);
// Build the MTP (Multi-Token Prediction) head with standard transformer block
void build_mtp_head(llm_graph_input_mem_hybrid * inp, ggml_tensor * inp_pos, int * sections);
const llama_model & model;
// Unfiltered hidden state from last main layer (before inp_out_ids filter).
// Used by MTP head for attention KV cache population across all batch tokens.
ggml_tensor * mtp_inp_hidden = nullptr;
};
// TODO: derive llm_build_delta_net_base instead

View File

@ -23,7 +23,10 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
// Only process main transformer layers (skip MTP layers appended at the end)
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@ -40,35 +43,43 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
// For the last main layer, process BOTH filtered and unfiltered paths:
// - Unfiltered: saved for MTP head (needs all batch tokens for attention KV cache)
// - Filtered: used for main model logits (only output tokens)
if (il == n_transformer_layers - 1 && inp_out_ids) {
// First: compute full layer output without filtering (for MTP)
ggml_tensor * full_residual = ggml_add(ctx0, cur, inpSA);
ggml_tensor * full_ffn_res = full_residual;
ggml_tensor * full_post_norm = build_norm(full_residual, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
ggml_tensor * full_ffn = build_layer_ffn(full_post_norm, il);
mtp_inp_hidden = ggml_add(ctx0, full_ffn, full_ffn_res);
mtp_inp_hidden = build_cvec(mtp_inp_hidden, il);
cb(mtp_inp_hidden, "mtp_inp_hidden", il);
// Second: filter for main model logits
cur = ggml_get_rows(ctx0, mtp_inp_hidden, inp_out_ids);
inpL = cur;
} else {
// Residual connection
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
ggml_tensor * ffn_residual = cur;
ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(attn_post_norm, "attn_post_norm", il);
cur = build_layer_ffn(attn_post_norm, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_ffn", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
// Residual connection
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
// Save the tensor before post-attention norm for residual connection
ggml_tensor * ffn_residual = cur;
// Post-attention norm
ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(attn_post_norm, "attn_post_norm", il);
// Dense FFN layer - without residual connection
cur = build_layer_ffn(attn_post_norm, il);
cb(cur, "ffn_out", il);
// Residual connection for FFN - add to the tensor from before post_attention_layernorm
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_ffn", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer
inpL = cur;
}
cur = inpL;
@ -85,6 +96,11 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
// Build MTP head if nextn_predict_layers > 0
if (hparams.nextn_predict_layers > 0) {
build_mtp_head(inp, inp_pos, sections);
}
}
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
@ -382,3 +398,145 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il)
return cur;
}
void llm_build_qwen35::build_mtp_head(
llm_graph_input_mem_hybrid * inp,
ggml_tensor * inp_pos,
int * sections) {
// MTP (Multi-Token Prediction) head for dense Qwen 3.5
//
// The MTP module takes the hidden state from the last main transformer layer
// and uses the model's built-in MTP head to produce draft logits.
//
// MTP forward pass:
// 1. sampled_token = argmax(main_logits)
// 2. emb = embed_tokens(sampled_token)
// 3. h_norm = RMSNorm(hidden_state, hnorm)
// 4. e_norm = RMSNorm(emb, enorm)
// 5. combined = eh_proj(concat(e_norm, h_norm))
// 6. Standard self-attention (Q/K/V with Q/K norms + RoPE)
// 7. Standard FFN (gate_proj + up_proj → SiLU → down_proj)
// 8. logits = lm_head(RMSNorm(output, mtp_norm))
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
const int64_t n_embd_head = hparams.n_embd_head_v();
// Use unfiltered hidden state for MTP (needs all batch tokens for attention KV cache)
ggml_tensor * hidden_state = mtp_inp_hidden ? mtp_inp_hidden : res->t_embd;
GGML_ASSERT(hidden_state != nullptr);
// Get logits for greedy token selection.
// If no filtering occurred (generation), reuse main logits to avoid expensive lm_head recomputation.
// If filtering occurred (prompt processing), recompute from unfiltered hidden state.
ggml_tensor * greedy_logits;
if (!mtp_inp_hidden || mtp_inp_hidden == res->t_embd) {
// No filtering — main logits already cover all tokens
greedy_logits = res->t_logits;
} else {
// Filtered — recompute logits from unfiltered hidden state
ggml_tensor * full_normed = build_norm(hidden_state, model.output_norm, nullptr, LLM_NORM_RMS, -1);
greedy_logits = build_lora_mm(model.output, full_normed);
}
ggml_tensor * greedy_tokens = ggml_argmax(ctx0, greedy_logits);
cb(greedy_tokens, "mtp_greedy_tokens", -1);
ggml_tensor * mtp_hidden = hidden_state;
for (uint32_t k = 0; k < hparams.nextn_predict_layers; ++k) {
const int il = n_transformer_layers + k;
const auto & layer = model.layers[il];
if (layer.nextn.eh_proj == nullptr) {
continue;
}
// Step 1: Get token embedding (shared with main model)
ggml_tensor * tok_embd = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * emb = ggml_get_rows(ctx0, tok_embd, greedy_tokens);
cb(emb, "mtp_token_embd", il);
// Step 2: Normalize hidden state and embedding
ggml_tensor * h_norm = build_norm(mtp_hidden, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(emb, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
cb(e_norm, "mtp_enorm", il);
// Step 3: Concatenate and project
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, 0); // [2*n_embd, n_tokens]
cb(concat, "mtp_concat", il);
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
cb(cur, "mtp_projected", il);
// Step 4: Full self-attention for the MTP head (same architecture as main model attention layers)
// The MTP layer has its own KV cache (allocated because is_recurrent(il) = false).
// We use the unfiltered hidden state (mtp_inp_hidden) so token count matches inp_pos.
{
ggml_tensor * attn_residual = cur;
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
cur = ggml_add(ctx0, cur, attn_residual);
}
// Step 5: Post-attention norm + FFN
{
ggml_tensor * ffn_residual = cur;
ggml_tensor * attn_post_norm = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(attn_post_norm, "mtp_attn_post_norm", il);
// Standard dense FFN (same as main model FFN)
cur = build_ffn(attn_post_norm,
layer.ffn_up, NULL, layer.ffn_up_s,
layer.ffn_gate, NULL, layer.ffn_gate_s,
layer.ffn_down, NULL, layer.ffn_down_s,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "mtp_ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
}
mtp_hidden = cur;
// Step 6: Final norm + LM head for draft logits
ggml_tensor * mtp_normed;
if (layer.nextn.shared_head_norm != nullptr) {
mtp_normed = build_norm(mtp_hidden, layer.nextn.shared_head_norm, nullptr, LLM_NORM_RMS, il);
} else {
// Use main model's output norm
mtp_normed = build_norm(mtp_hidden, model.output_norm, nullptr, LLM_NORM_RMS, il);
}
cb(mtp_normed, "mtp_head_norm", il);
ggml_tensor * lm_head = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
// FastMTP: vocabulary trimming — only compute logits for top-K tokens
// instead of full 248K vocabulary. Most tokenizers order by frequency,
// so tokens 0..K-1 cover ~95%+ of generated code tokens.
// This reduces the lm_head matmul from [4096,248K] to [4096,32K] (~8x faster).
const int64_t mtp_vocab_size = std::min(lm_head->ne[1], (int64_t)32768);
ggml_tensor * lm_head_reduced = ggml_view_2d(ctx0, lm_head,
lm_head->ne[0], mtp_vocab_size, lm_head->nb[1], 0);
ggml_tensor * mtp_logits = build_lora_mm(lm_head_reduced, mtp_normed);
cb(mtp_logits, "mtp_logits", il);
// Store MTP outputs in graph result
res->t_embd_mtp = mtp_hidden;
res->t_logits_mtp = mtp_logits;
// For recursive MTP (multiple layers), feed greedy tokens forward
if (k + 1 < hparams.nextn_predict_layers) {
greedy_tokens = ggml_argmax(ctx0, mtp_logits);
cb(greedy_tokens, "mtp_greedy_next", il);
}
ggml_build_forward_expand(gf, mtp_logits);
}
}

View File

@ -149,6 +149,15 @@ struct server_slot {
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
// Inline MTP (Multi-Token Prediction) state.
// Instead of using the speculative framework (which has M-RoPE and SSM
// rollback issues), we propose one draft token from MTP logits and verify
// it in the next decode step. No seq_rm or rollback needed.
llama_token mtp_draft_token = -1; // proposed draft token (-1 = none)
int mtp_i_batch = -1; // batch index of the draft token
bool mtp_pending = false; // true when draft is in the batch awaiting verification
bool mtp_cooldown = false; // skip MTP proposal for one iteration after draft processing
// stats
size_t n_sent_text = 0; // number of sent text character
@ -179,6 +188,10 @@ struct server_slot {
drafted.clear();
i_batch_dft.clear();
mtp_draft_token = -1;
mtp_i_batch = -1;
mtp_pending = false;
mtp_cooldown = false;
generated_tokens.clear();
generated_token_probs.clear();
json_schema = json();
@ -753,11 +766,24 @@ private:
slots.clear();
const bool can_spec = common_speculative_is_compat(ctx);
bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// Auto-detect MTP: if model has MTP layers and no speculative type
// is explicitly set, auto-enable MTP speculative decoding.
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
const int32_t n_mtp = llama_model_n_mtp_layers(llama_get_model(ctx));
if (n_mtp > 0 && can_spec) {
SRV_INF("model has %d MTP layer(s) — auto-enabling MTP speculative decoding\n", n_mtp);
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
params_base.speculative.n_max = 1; // MTP-1: one draft token per step
} else if (n_mtp > 0) {
SRV_INF("model has %d MTP layer(s) but speculative context not compatible\n", n_mtp);
}
}
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@ -2066,42 +2092,34 @@ private:
}
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max = slot.get_n_draft_max();
const bool is_hybrid = llama_model_is_hybrid(model);
if (n_draft_max > 0) {
// Standard speculative decoding for non-hybrid models
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
const auto & params_spec = slot.task->params.speculative;
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
if (draft.size() > (size_t) n_draft_max) {
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
} else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
@ -2114,7 +2132,6 @@ private:
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
@ -2821,7 +2838,7 @@ private:
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
@ -2833,13 +2850,135 @@ private:
const int tok_idx = slot.i_batch - i;
// --- Two-phase MTP: verify draft after decoding sampled token ---
// The sampled token was decoded alone (1-token batch).
// Now verify the draft against main model logits.
// If accepted: decode draft in a second pass (correct checkpoint position).
// If rejected: skip draft decode (no seq_rm needed — state is clean).
if (slot.mtp_pending) {
// Sample from main model at the sampled token's position
llama_token verified = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
common_sampler_accept(slot.smpl.get(), verified, true);
slot.n_draft_total += 1;
const int64_t t_current = ggml_time_us();
if (slot.n_decoded == 0) {
slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
if (verified == slot.mtp_draft_token) {
// ACCEPTED — decode the draft token in a second pass.
// At this point the recurrent state checkpoint is at the
// correct position (after sampled, before draft).
slot.n_draft_accepted += 1;
slot.n_decoded += 1;
// Output the verified/sampled token
completion_token_output result_sampled;
result_sampled.tok = verified;
result_sampled.text_to_send = common_token_to_piece(ctx, result_sampled.tok, accept_special_token(slot, result_sampled.tok));
result_sampled.prob = 1.0f;
bool should_stop = !process_token(result_sampled, slot);
if (should_stop) {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
slot.mtp_pending = false;
slot.mtp_draft_token = -1;
continue;
}
// Build a 1-token batch for the draft token and decode it
llama_batch draft_batch = llama_batch_init(1, 0, 1);
common_batch_clear(draft_batch);
common_batch_add(draft_batch, slot.mtp_draft_token, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.mtp_draft_token);
fprintf(stderr, "[MTP-2PHASE] draft ACCEPTED (verified=%d), decoding draft at pos %d\n",
(int)verified, (int)(slot.prompt.tokens.pos_next() - 1));
fflush(stderr);
int ret2 = llama_decode(ctx, draft_batch);
llama_batch_free(draft_batch);
if (ret2 != 0) {
fprintf(stderr, "[MTP-2PHASE] ERROR: draft decode failed with ret=%d\n", ret2);
fflush(stderr);
// Fall through — state is still valid at pre-draft position
} else {
// Sample bonus token from draft's logits (the NEXT-NEXT token)
llama_token bonus = common_sampler_sample(slot.smpl.get(), ctx, 0);
common_sampler_accept(slot.smpl.get(), bonus, true);
slot.n_decoded += 1;
slot.sampled = bonus;
// Output the bonus token (draft/verified was already output above)
completion_token_output result_bonus;
result_bonus.tok = bonus;
result_bonus.text_to_send = common_token_to_piece(ctx, result_bonus.tok, accept_special_token(slot, result_bonus.tok));
result_bonus.prob = 1.0f;
if (!process_token(result_bonus, slot)) {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
slot.mtp_pending = false;
slot.mtp_draft_token = -1;
continue;
}
// Inform speculative state about the acceptance
common_speculative_accept(slot.spec, 1);
}
} else {
// REJECTED — draft was never decoded, state is clean.
// No seq_rm needed!
fprintf(stderr, "[MTP-2PHASE] draft REJECTED (verified=%d, draft=%d)\n",
(int)verified, (int)slot.mtp_draft_token);
fflush(stderr);
slot.sampled = verified;
slot.n_decoded += 1;
// Output the verified token
completion_token_output result_verified;
result_verified.tok = verified;
result_verified.text_to_send = common_token_to_piece(ctx, result_verified.tok, accept_special_token(slot, result_verified.tok));
result_verified.prob = 1.0f;
if (!process_token(result_verified, slot)) {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
slot.mtp_pending = false;
slot.mtp_draft_token = -1;
continue;
}
}
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
slot.mtp_pending = false;
slot.mtp_draft_token = -1;
slot.mtp_i_batch = -1;
slot.i_batch = -1;
continue; // done with this slot for this decode step
}
// --- Normal sampling (no pending MTP draft) ---
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
const int64_t t_current = ggml_time_us();
slot.n_decoded += 1;
@ -2855,14 +2994,13 @@ private:
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
result.prob = 1.0f;
if (slot.task->params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
@ -2904,7 +3042,15 @@ private:
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
// Remove rejected draft tokens from KV cache.
// For hybrid SSM/DeltaNet, seq_rm may fail. In that case,
// just log and continue — the recurrent state has the draft
// token baked in, but the checkpoint mechanism in
// llama-memory-recurrent.cpp should handle rollback internally
// during the next find_slot call.
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1)) {
SLT_WRN(slot, "seq_rm failed at pos %d\n", (int)slot.prompt.n_tokens());
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;