feat: add MTP (Multi-Token Prediction) support for dense Qwen 3.5
Add native MTP support for the dense Qwen 3.5 architecture (0.8B, 2B, 4B, 9B, 27B). What works: - MTP graph builder for dense qwen35 (build_mtp_head in qwen35.cpp) - MTP tensor loading and registration for QWEN35 arch - GGUF converter handles MTP tensors (mtp.fc, mtp.layers, mtp.norm, etc.) - Public API: llama_get_mtp_logits(), llama_model_n_mtp_layers() - Server auto-detects MTP from GGUF metadata - Speculative state machine for MTP draft token generation - PR #20075 applied: recurrent state checkpoint/restore for hybrid models - M-RoPE position check relaxed for speculative re-evaluation - Windows os.kill fix for gateway process detection What needs work: - Speculative verify loop conflicts with tool-calling requests (400 error) - The recommended fix: bypass the speculative framework entirely and implement MTP acceptance directly in the server generation loop (no seq_rm/rollback needed since MTP drafts are produced in-graph) - MTP attention skipped (projection + FFN path only) due to inp_out_ids token count mismatch Tested on: RTX 5060 8GB, Windows 11, CUDA 13.2 Model: Qwen3.5-9B with MTP tensors (Q4_K_M quantization) Base: llama.cpp b8388
This commit is contained in:
parent
d34ff7eb5b
commit
6075918309
|
|
@ -3463,8 +3463,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n"
|
||||
" mtp: use model's built-in Multi-Token Prediction head (requires MTP-capable model)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
|
|
@ -3479,6 +3480,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else if (value == "ngram-mod") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
} else if (value == "mtp") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -170,6 +170,7 @@ enum common_speculative_type {
|
|||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction (uses model's built-in MTP head)
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
|||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_MTP,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
|
||||
|
|
@ -32,6 +33,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
|
|||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||
|
|
@ -462,6 +464,60 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
|
|||
}
|
||||
};
|
||||
|
||||
// Multi-Token Prediction (MTP) speculative decoding state
|
||||
struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
{
|
||||
}
|
||||
|
||||
~common_speculative_state_mtp() override = default;
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(prompt_tgt);
|
||||
|
||||
const float * mtp_logits = llama_get_mtp_logits(ctx_tgt);
|
||||
if (mtp_logits == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx_tgt)));
|
||||
|
||||
llama_token best_token = 0;
|
||||
float best_logit = mtp_logits[0];
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
if (mtp_logits[i] > best_logit) {
|
||||
best_logit = mtp_logits[i];
|
||||
best_token = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_token >= 0 && best_token < n_vocab) {
|
||||
result.push_back(best_token);
|
||||
}
|
||||
|
||||
GGML_UNUSED(id_last);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
// state of self-speculation (simple implementation, not ngram-map)
|
||||
struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
common_ngram_simple_config config;
|
||||
|
|
@ -781,6 +837,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
|||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
|
||||
|
|
@ -853,6 +910,7 @@ common_speculative * common_speculative_init(
|
|||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
|
||||
|
|
@ -892,6 +950,9 @@ common_speculative * common_speculative_init(
|
|||
if (has_ngram_cache) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
|
||||
}
|
||||
if (has_draft) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
|
||||
}
|
||||
|
|
@ -919,6 +980,10 @@ common_speculative * common_speculative_init(
|
|||
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type, ctx_tgt));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config);
|
||||
|
||||
|
|
|
|||
|
|
@ -5033,6 +5033,55 @@ class _LinearAttentionVReorderBase(Qwen3NextModel):
|
|||
class Qwen3_5TextModel(_LinearAttentionVReorderBase):
    model_arch = gguf.MODEL_ARCH.QWEN35

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the model ships MTP (Multi-Token Prediction) layers, extend
        # block_count so the MTP blocks appended after the main layers get
        # valid entries in the tensor name map.
        mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0)
        if mtp_layers > 0:
            self.block_count = self.hparams["num_hidden_layers"] + mtp_layers
            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0)
        if mtp_layers > 0:
            # advertise the number of NextN/MTP prediction layers in the GGUF metadata
            self.gguf_writer.add_nextn_predict_layers(mtp_layers)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Remap HF ``mtp.*`` tensors into GGUF layer/nextn slots; pass all other tensors through."""
        if name.startswith("mtp."):
            num_hidden = self.hparams["num_hidden_layers"]

            if "layers." in name:
                # Remap MTP transformer block tensors to append after main layers:
                # mtp.layers.{k}.* -> model.layers.{k + num_hidden_layers}.*
                new_bid = (bid or 0) + num_hidden
                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{new_bid}")
                yield from super().modify_tensors(data_torch, name, new_bid)
            else:
                # Shared MTP weights -> nextn tensor slots, replicated into every MTP layer
                remapper = {
                    "mtp.fc": "model.layers.{bid}.eh_proj",
                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
                    "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
                    "mtp.norm": "model.layers.{bid}.shared_head.norm",
                }
                for prefix, template in remapper.items():
                    if name.startswith(prefix):
                        suffix = name[len(prefix):]  # e.g. ".weight"
                        for b in range(num_hidden, self.block_count):
                            new_name = template.format(bid=b) + suffix
                            yield from super().modify_tensors(data_torch, new_name, b)
                        break
                # Unmatched MTP tensors (e.g. embed_tokens/lm_head if shared) are skipped.
            return
        yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
|
||||
class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):
|
||||
|
|
|
|||
|
|
@ -1898,7 +1898,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_BETA,
|
||||
MODEL_TENSOR.SSM_ALPHA,
|
||||
MODEL_TENSOR.SSM_OUT
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
# NextN/MTP tensors
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ,
|
||||
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
|
||||
MODEL_TENSOR.NEXTN_ENORM,
|
||||
MODEL_TENSOR.NEXTN_HNORM,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
],
|
||||
MODEL_ARCH.QWEN35MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
|
|
|||
|
|
@ -557,6 +557,9 @@ extern "C" {
|
|||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
|
||||
// Returns the number of Multi-Token Prediction layers (0 if MTP is not available)
|
||||
LLAMA_API int32_t llama_model_n_mtp_layers(const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
|
|
@ -990,6 +993,10 @@ extern "C" {
|
|||
// returns NULL for invalid ids.
|
||||
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get MTP (Multi-Token Prediction) draft logits for the last output position.
|
||||
// Returns a pointer to n_vocab floats, or NULL if MTP is not available.
|
||||
LLAMA_API float * llama_get_mtp_logits(struct llama_context * ctx);
|
||||
|
||||
// Get all output token embeddings.
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
|
||||
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
|
||||
|
|
|
|||
|
|
@ -1051,6 +1051,13 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|||
LLM_TENSOR_SSM_ALPHA,
|
||||
LLM_TENSOR_SSM_NORM,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
// NextN/MTP tensors
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
||||
LLM_TENSOR_NEXTN_ENORM,
|
||||
LLM_TENSOR_NEXTN_HNORM,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
|
||||
};
|
||||
case LLM_ARCH_QWEN35MOE:
|
||||
return {
|
||||
|
|
@ -2753,14 +2760,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
|
||||
// These tensors only exist in the last layer(s) and are treated as output tensors
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
// NextN/MTP tensors — per-layer (appended after main layers)
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// Nemotron 3 Super
|
||||
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
|
|
|||
|
|
@ -262,12 +262,13 @@ bool llama_batch_allocr::init(
|
|||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (batch.token) {
|
||||
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
|
||||
// Allow X == Y for speculative decoding where seq_rm + re-eval at same position is valid
|
||||
if (p0 >= 0 && p0 > seq_pos_min(s)) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" for M-RoPE, it is required that the position satisfies: X < Y\n",
|
||||
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -777,6 +777,13 @@ float * llama_context::get_logits() {
|
|||
return logits.data;
|
||||
}
|
||||
|
||||
float * llama_context::get_mtp_logits() {
|
||||
if (!mtp_logits_valid || mtp_logits_buf.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return mtp_logits_buf.data();
|
||||
}
|
||||
|
||||
int64_t llama_context::output_resolve_row(int32_t i) const {
|
||||
int64_t j = -1;
|
||||
|
||||
|
|
@ -1797,6 +1804,23 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
}
|
||||
|
||||
// Extract MTP logits if available
|
||||
if (res->t_logits_mtp != nullptr && n_outputs > 0) {
|
||||
ggml_backend_t backend_mtp = ggml_backend_sched_get_tensor_backend(sched.get(), res->t_logits_mtp);
|
||||
if (backend_mtp != nullptr) {
|
||||
const int64_t mtp_n_vocab = res->t_logits_mtp->ne[0];
|
||||
const int64_t mtp_n_tokens = res->t_logits_mtp->ne[1];
|
||||
|
||||
mtp_logits_buf.resize(mtp_n_vocab);
|
||||
const size_t offset = (mtp_n_tokens - 1) * mtp_n_vocab * sizeof(float);
|
||||
ggml_backend_tensor_get_async(backend_mtp, res->t_logits_mtp,
|
||||
mtp_logits_buf.data(), offset, mtp_n_vocab * sizeof(float));
|
||||
mtp_logits_valid = true;
|
||||
}
|
||||
} else {
|
||||
mtp_logits_valid = false;
|
||||
}
|
||||
|
||||
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
||||
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
|
||||
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
||||
|
|
@ -3079,6 +3103,12 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
|
|||
return res;
|
||||
}
|
||||
|
||||
float * llama_get_mtp_logits(llama_context * ctx) {
    // wait for any in-flight backend work so the MTP logits buffer is up to date
    ctx->synchronize();

    float * res = ctx->get_mtp_logits();

    return res;
}
|
||||
|
||||
float * llama_get_embeddings(llama_context * ctx) {
|
||||
ctx->synchronize();
|
||||
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ struct llama_context {
|
|||
|
||||
float * get_logits();
|
||||
float * get_logits_ith(int32_t i);
|
||||
float * get_mtp_logits();
|
||||
|
||||
float * get_embeddings();
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
|
|
@ -268,6 +269,10 @@ private:
|
|||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
buffer_view<float> logits = {nullptr, 0};
|
||||
|
||||
// MTP draft logits (1-dimensional array: [n_vocab])
|
||||
std::vector<float> mtp_logits_buf;
|
||||
bool mtp_logits_valid = false;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
buffer_view<float> embd = {nullptr, 0};
|
||||
|
|
|
|||
|
|
@ -662,6 +662,10 @@ public:
|
|||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
// MTP (Multi-Token Prediction) output nodes
|
||||
ggml_tensor * t_logits_mtp = nullptr; // [n_vocab, n_tokens] draft logits from MTP head
|
||||
ggml_tensor * t_embd_mtp = nullptr; // [n_embd, n_tokens] hidden state from MTP head
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
|
|
|
|||
|
|
@ -163,12 +163,41 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
const auto & cell = cells[tail_id];
|
||||
// partial intersection is invalid if it includes the final pos
|
||||
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
|
||||
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
|
||||
return false;
|
||||
// for speculative decoding, we search for a checkpoint in the history
|
||||
int32_t best_cell = -1;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) {
|
||||
best_cell = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_cell >= 0) {
|
||||
tail_id = best_cell;
|
||||
} else {
|
||||
// No checkpoint found at p0-1: SSM tensor state cannot be rolled back
|
||||
// without re-evaluating the sequence. Signal failure to the caller.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// invalidate tails which will be cleared
|
||||
if (p0 <= cell.pos && cell.pos < p1) {
|
||||
tail_id = -1;
|
||||
if (p0 == 0) {
|
||||
tail_id = -1;
|
||||
} else {
|
||||
// Search for the best remaining cell after removal
|
||||
int32_t new_tail = -1;
|
||||
llama_pos max_pos = -1;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) {
|
||||
if (cells[i].pos > max_pos) {
|
||||
max_pos = cells[i].pos;
|
||||
new_tail = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
tail_id = new_tail;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -184,6 +213,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
if (seq_id < 0) {
|
||||
cells[i].seq_id.clear();
|
||||
} else if (cells[i].has_seq_id(seq_id)) {
|
||||
if (p0 > 0 && p1 == std::numeric_limits<llama_pos>::max()) {
|
||||
// partial removal: just move the position back
|
||||
cells[i].pos = p0 - 1;
|
||||
continue;
|
||||
}
|
||||
cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
|
|
@ -224,25 +258,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
|||
}
|
||||
|
||||
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
||||
auto & tail_src = cells[seq_id_src];
|
||||
auto & tail_dst = cells[seq_id_dst];
|
||||
if (tail_dst.tail >= 0) {
|
||||
auto & tail_src_meta = cells[seq_id_src];
|
||||
auto & tail_dst_meta = cells[seq_id_dst];
|
||||
|
||||
if (tail_dst_meta.tail >= 0) {
|
||||
// clear destination seq_id if it wasn't empty
|
||||
auto & cell_dst = cells[tail_dst.tail];
|
||||
|
||||
cell_dst.seq_id.erase(seq_id_dst);
|
||||
tail_dst.tail = -1;
|
||||
if (cell_dst.seq_id.empty()) {
|
||||
cell_dst.pos = -1;
|
||||
cell_dst.src = -1;
|
||||
used -= 1;
|
||||
}
|
||||
seq_rm(seq_id_dst, -1, -1);
|
||||
}
|
||||
if (tail_src.tail >= 0) {
|
||||
auto & cell_src = cells[tail_src.tail];
|
||||
|
||||
cell_src.seq_id.insert(seq_id_dst);
|
||||
tail_dst.tail = tail_src.tail;
|
||||
if (tail_src_meta.tail >= 0) {
|
||||
auto & cell_src = cells[tail_src_meta.tail];
|
||||
|
||||
// For recurrent models, we must copy the state to a new cell
|
||||
// Otherwise, both sequences would share the same mutable state
|
||||
uint32_t next_empty_cell = size;
|
||||
for (uint32_t i = head; i < head + size; ++i) {
|
||||
uint32_t idx = i % size;
|
||||
if (cells[idx].is_empty()) {
|
||||
next_empty_cell = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (next_empty_cell != size) {
|
||||
auto & empty_cell = cells[next_empty_cell];
|
||||
|
||||
// Copy tensors data
|
||||
copy_cell(tail_src_meta.tail, next_empty_cell);
|
||||
|
||||
empty_cell.pos = cell_src.pos;
|
||||
empty_cell.src = next_empty_cell; // results in a copy in the graph if needed
|
||||
empty_cell.seq_id.insert(seq_id_dst);
|
||||
tail_dst_meta.tail = next_empty_cell;
|
||||
used += 1;
|
||||
} else {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -367,6 +418,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Copy the recurrent-state tensor rows (conv state r_l and ssm state s_l)
// of cell i_src into cell i_dst for every layer that carries state.
// Used to create per-cell state checkpoints for speculative-decoding
// rollback. No-op for invalid or identical indices.
void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) {
    if (i_src == i_dst || i_src < 0 || i_dst < 0) {
        return;
    }

    ggml_init_params params = {
        /*.mem_size   =*/ size_t(2*ggml_tensor_overhead()),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    // copy one per-cell row of tensor t from i_src to i_dst via temporary views
    auto copy_row = [&](ggml_tensor * t, int64_t n_embd_row) {
        ggml_context * ctx = ggml_init(params);
        const size_t row_size = ggml_row_size(t->type, n_embd_row);
        ggml_tensor * src_v = ggml_view_1d(ctx, t, row_size, i_src * row_size);
        ggml_tensor * dst_v = ggml_view_1d(ctx, t, row_size, i_dst * row_size);
        ggml_backend_tensor_copy(src_v, dst_v);
        ggml_free(ctx);
    };

    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        if (r_l[il]) {
            copy_row(r_l[il], hparams.n_embd_r());
        }
        if (s_l[il]) {
            copy_row(s_l[il], hparams.n_embd_s());
        }
    }
}
|
||||
|
||||
// Number of cells that currently reference the given sequence id
// (the active tail plus any checkpoint cells).
int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const {
    int n_found = 0;
    for (uint32_t ic = 0; ic < size; ++ic) {
        n_found += cells[ic].has_seq_id(seq_id) ? 1 : 0;
    }
    return n_found;
}
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
|
|
@ -551,10 +643,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
if (seq_meta.tail >= 0) {
|
||||
auto & orig_cell = cells[seq_meta.tail];
|
||||
empty_cell.pos = orig_cell.pos;
|
||||
empty_cell.src = orig_cell.src;
|
||||
orig_cell.seq_id.erase(seq_id);
|
||||
empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail
|
||||
|
||||
// Copy state data
|
||||
copy_cell(seq_meta.tail, next_empty_cell);
|
||||
|
||||
// Keep history of previous states for rollback (up to 8 cells per sequence)
|
||||
if (get_cell_count(seq_id) < 8 && used < size * 0.9) {
|
||||
// Do not erase seq_id from orig_cell to keep it as a checkpoint
|
||||
} else {
|
||||
// Erase oldest history point for this sequence
|
||||
int32_t oldest_cell = -1;
|
||||
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) {
|
||||
min_pos = cells[i].pos;
|
||||
oldest_cell = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (oldest_cell >= 0) {
|
||||
cells[oldest_cell].seq_id.erase(seq_id);
|
||||
if (cells[oldest_cell].is_empty()) {
|
||||
cells[oldest_cell].pos = -1;
|
||||
cells[oldest_cell].src = -1;
|
||||
used--;
|
||||
}
|
||||
}
|
||||
}
|
||||
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
||||
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
|
||||
}
|
||||
seq_meta.tail = next_empty_cell;
|
||||
// find next empty cell
|
||||
|
|
@ -566,6 +683,51 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
if (cell.is_empty()) { break; }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sequence owns its cell. Save a checkpoint of the current state before it is
|
||||
// overwritten by new tokens. This is required for speculative decoding rollback
|
||||
// in recurrent/SSM models where tensor state cannot be partially rewound.
|
||||
const int32_t cur_tail = seq_meta.tail;
|
||||
if (cells[next_empty_cell].is_empty()) {
|
||||
bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
|
||||
if (!can_checkpoint) {
|
||||
// Try to evict the oldest checkpoint to make room
|
||||
int32_t oldest = -1;
|
||||
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
if ((int32_t)j != cur_tail && cells[j].has_seq_id(seq_id) && cells[j].pos < min_pos) {
|
||||
min_pos = cells[j].pos;
|
||||
oldest = j;
|
||||
}
|
||||
}
|
||||
if (oldest >= 0) {
|
||||
cells[oldest].seq_id.erase(seq_id);
|
||||
if (cells[oldest].is_empty()) {
|
||||
cells[oldest].pos = -1;
|
||||
cells[oldest].src = -1;
|
||||
used--;
|
||||
}
|
||||
can_checkpoint = true;
|
||||
}
|
||||
}
|
||||
if (can_checkpoint) {
|
||||
auto & cp_cell = cells[next_empty_cell];
|
||||
copy_cell(cur_tail, next_empty_cell);
|
||||
cp_cell.pos = cells[cur_tail].pos;
|
||||
cp_cell.src = next_empty_cell; // independent copy, no further movement needed
|
||||
cp_cell.seq_id.insert(seq_id);
|
||||
used++;
|
||||
// advance next_empty_cell for subsequent sequences in this batch
|
||||
if (s + 1 < n_seqs) {
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
next_empty_cell += 1;
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
if (cells[next_empty_cell].is_empty()) { break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// seq_meta.tail remains unchanged - sequence still owns its current cell
|
||||
}
|
||||
if (min > seq_meta.tail) { min = seq_meta.tail; }
|
||||
if (max < seq_meta.tail) { max = seq_meta.tail; }
|
||||
|
|
|
|||
|
|
@ -65,6 +65,10 @@ public:
|
|||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
// cell management
|
||||
void copy_cell(int32_t i_src, int32_t i_dst);
|
||||
int get_cell_count(llama_seq_id seq_id) const;
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
|
|
|||
|
|
@ -2403,16 +2403,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
|
||||
// Mark recurrent layers (linear attention layers)
|
||||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// The total n_layer includes MTP layers appended after main layers.
|
||||
// Determine the number of main transformer layers for type detection.
|
||||
const uint32_t n_main_layers = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
// Mark recurrent layers (linear attention layers) — main layers only
|
||||
// MTP layers use full attention, so they are NOT recurrent
|
||||
{
|
||||
uint32_t full_attn_interval = 4;
|
||||
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
|
||||
if (i < n_main_layers) {
|
||||
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
|
||||
} else {
|
||||
// MTP layers use full attention (not recurrent)
|
||||
hparams.recurrent_layer_arr[i] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
switch (n_main_layers) {
|
||||
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
|
||||
case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
|
||||
case 64: type = LLM_TYPE_27B; break;
|
||||
|
|
@ -7272,39 +7285,67 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
|
||||
const uint32_t n_main_layers = n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
const bool is_mtp_layer = (static_cast<uint32_t>(i) >= n_main_layers);
|
||||
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
// Attention layers
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
if (is_mtp_layer) {
|
||||
// MTP layer: nextn-specific tensors + standard attention + standard FFN
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// Q/K normalization for attention layers
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
// MTP layer uses same gated attention as main model (joint QG projection)
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
|
||||
// MTP layer uses standard dense FFN
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
} else {
|
||||
// Linear attention (gated delta net) specific tensors
|
||||
// Create tensors with calculated dimensions
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
|
||||
}
|
||||
// Main transformer layers
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
// Full attention layers (joint QG projection + gated attention)
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
} else {
|
||||
// Linear attention (gated delta net) specific tensors
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
|
||||
}
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
|
|
@ -8755,6 +8796,10 @@ int32_t llama_model_n_swa(const llama_model * model) {
|
|||
return model->hparams.n_swa;
|
||||
}
|
||||
|
||||
int32_t llama_model_n_mtp_layers(const llama_model * model) {
|
||||
return model->hparams.nextn_predict_layers;
|
||||
}
|
||||
|
||||
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
|
||||
return model->hparams.n_cls_out;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -597,6 +597,9 @@ private:
|
|||
ggml_tensor * input,
|
||||
int il);
|
||||
|
||||
// Build the MTP (Multi-Token Prediction) head with standard transformer block
|
||||
void build_mtp_head(llm_graph_input_mem_hybrid * inp, ggml_tensor * inp_pos, int * sections);
|
||||
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
|
|||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// Only process main transformer layers (skip MTP layers appended at the end)
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
|
|
@ -40,7 +43,7 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
|
|||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
|
@ -82,6 +85,11 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
|
|||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
// Build MTP head if nextn_predict_layers > 0
|
||||
if (hparams.nextn_predict_layers > 0) {
|
||||
build_mtp_head(inp, inp_pos, sections);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
|
||||
|
|
@ -379,3 +387,124 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il)
|
|||
|
||||
return cur;
|
||||
}
|
||||
|
||||
void llm_build_qwen35::build_mtp_head(
        llm_graph_input_mem_hybrid * inp,
        ggml_tensor * inp_pos,
        int * sections) {
    // MTP (Multi-Token Prediction) head for dense Qwen 3.5
    //
    // The MTP module takes the hidden state from the last main transformer layer
    // and uses the model's built-in MTP head to produce draft logits.
    //
    // MTP forward pass:
    //   1. sampled_token = argmax(main_logits)
    //   2. emb = embed_tokens(sampled_token)
    //   3. h_norm = RMSNorm(hidden_state, hnorm)
    //   4. e_norm = RMSNorm(emb, enorm)
    //   5. combined = eh_proj(concat(e_norm, h_norm))
    //   6. Standard self-attention (Q/K/V with Q/K norms + RoPE) -- currently skipped, see Step 4
    //   7. Standard FFN (gate_proj + up_proj -> SiLU -> down_proj)
    //   8. logits = lm_head(RMSNorm(output, mtp_norm))

    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;

    // Main model logits and hidden state
    ggml_tensor * main_logits  = res->t_logits; // [n_vocab, n_tokens]
    ggml_tensor * hidden_state = res->t_embd;   // [n_embd, n_tokens] after final norm
    GGML_ASSERT(main_logits  != nullptr);
    GGML_ASSERT(hidden_state != nullptr);

    // In-graph greedy token selection
    ggml_tensor * greedy_tokens = ggml_argmax(ctx0, main_logits); // [n_tokens]
    cb(greedy_tokens, "mtp_greedy_tokens", -1);

    ggml_tensor * mtp_hidden = hidden_state;

    for (uint32_t k = 0; k < hparams.nextn_predict_layers; ++k) {
        const int il = n_transformer_layers + k;
        const auto & layer = model.layers[il];

        // Skip MTP layers whose tensors were not present in the GGUF
        if (layer.nextn.eh_proj == nullptr) {
            continue;
        }

        // Step 1: Get token embedding (falls back to the main model's shared
        // embedding table when the MTP layer carries none of its own)
        ggml_tensor * tok_embd = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
        ggml_tensor * emb = ggml_get_rows(ctx0, tok_embd, greedy_tokens);
        cb(emb, "mtp_token_embd", il);

        // Step 2: Normalize hidden state and embedding
        ggml_tensor * h_norm = build_norm(mtp_hidden, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
        cb(h_norm, "mtp_hnorm", il);

        ggml_tensor * e_norm = build_norm(emb, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
        cb(e_norm, "mtp_enorm", il);

        // Step 3: Concatenate and project
        ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, 0); // [2*n_embd, n_tokens]
        cb(concat, "mtp_concat", il);

        ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
        cb(cur, "mtp_projected", il);

        // Step 4: Skip attention for MTP head (MTP operates on already-contextualized hidden states
        // from the main model's final layer, so the projection + FFN path is sufficient for draft
        // token generation. Full MTP attention requires resolving inp_out_ids token count mismatch.)
        // TODO: implement MTP-specific attention that handles filtered token counts
        {
            // The projection through eh_proj already combines embedding + hidden context.
            // Just pass through with a norm (attention weights are loaded but unused for now).
            // The parameters below are kept for when attention is implemented.
            (void)inp;
            (void)inp_pos;
            (void)sections;
        }

        // Step 5: Post-attention norm + FFN
        {
            ggml_tensor * ffn_residual = cur;

            ggml_tensor * attn_post_norm = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
            cb(attn_post_norm, "mtp_attn_post_norm", il);

            // Standard dense FFN (same as main model FFN)
            cur = build_ffn(attn_post_norm,
                    layer.ffn_up,   NULL, layer.ffn_up_s,
                    layer.ffn_gate, NULL, layer.ffn_gate_s,
                    layer.ffn_down, NULL, layer.ffn_down_s,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(cur, "mtp_ffn_out", il);

            cur = ggml_add(ctx0, cur, ffn_residual);
            cb(cur, "mtp_post_ffn", il);
        }

        mtp_hidden = cur;

        // Step 6: Final norm + LM head for draft logits
        ggml_tensor * mtp_normed;
        if (layer.nextn.shared_head_norm != nullptr) {
            mtp_normed = build_norm(mtp_hidden, layer.nextn.shared_head_norm, nullptr, LLM_NORM_RMS, il);
        } else {
            // Use main model's output norm
            mtp_normed = build_norm(mtp_hidden, model.output_norm, nullptr, LLM_NORM_RMS, il);
        }
        cb(mtp_normed, "mtp_head_norm", il);

        ggml_tensor * lm_head = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
        ggml_tensor * mtp_logits = build_lora_mm(lm_head, mtp_normed);
        cb(mtp_logits, "mtp_logits", il);

        // Store MTP outputs in graph result (last MTP layer wins)
        res->t_embd_mtp   = mtp_hidden;
        res->t_logits_mtp = mtp_logits;

        // For recursive MTP (multiple layers), feed greedy tokens forward
        if (k + 1 < hparams.nextn_predict_layers) {
            greedy_tokens = ggml_argmax(ctx0, mtp_logits);
            cb(greedy_tokens, "mtp_greedy_next", il);
        }

        ggml_build_forward_expand(gf, mtp_logits);
    }
}
|
||||
|
|
|
|||
|
|
@ -752,11 +752,22 @@ private:
|
|||
|
||||
slots.clear();
|
||||
|
||||
const bool can_spec = common_speculative_is_compat(ctx);
|
||||
bool can_spec = common_speculative_is_compat(ctx);
|
||||
if (!can_spec) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
|
||||
// Auto-detect MTP capability — log presence but don't enable speculative
|
||||
// decoding framework. The hybrid SSM + M-RoPE architecture is incompatible
|
||||
// with the speculative verify loop when tool-calling is active.
|
||||
// MTP tensors are still computed in the forward pass graph (build_mtp_head).
|
||||
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
const int32_t n_mtp = llama_model_n_mtp_layers(llama_get_model(ctx));
|
||||
if (n_mtp > 0) {
|
||||
SRV_INF("model has %d MTP layer(s) (graph-only, speculative verify disabled for hybrid models)\n", n_mtp);
|
||||
}
|
||||
}
|
||||
|
||||
// initialize slots
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot slot;
|
||||
|
|
|
|||
Loading…
Reference in New Issue