From a3e29da02add9759a24f1daf39a275427545f434 Mon Sep 17 00:00:00 2001
From: samuel
Date: Fri, 19 Dec 2025 20:41:35 -0300
Subject: [PATCH] glm-moe: allow skipping MTP tensor loading to save VRAM

Adds a new `mtp` boolean to `llama_model_params`. When it is false (the
default):

1. The loader skips the MTP-specific tensors (NextN layers) by marking
   them with `TENSOR_SKIP`.
2. The KV cache size calculation excludes the MTP layer
   (`n_layer_kv_from_start`).

This reduces VRAM usage and load time for users running GLM-4.5/4.6 in
standard generation mode.
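
Example of opting in from the C API (a sketch; the model path is
illustrative, only the new `mtp` field differs from existing usage):

    llama_model_params mparams = llama_model_default_params();
    mparams.mtp = true; // keep the NextN/MTP tensors and their KV cache layer
    llama_model * model = llama_model_load_from_file("glm-4.5-air.gguf", mparams);

With the default `mtp = false`, the same call skips the NextN tensors at
load time and sizes the KV cache without the MTP layer.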
---
 common/common.cpp   |  1 +
 common/common.h     |  2 +-
 include/llama.h     |  1 +
 src/llama-model.cpp | 19 ++++++++++++++---
 4 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index d4e8c7405e..7c1297574a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1351,6 +1351,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.check_tensors = params.check_tensors;
     mparams.use_extra_bufts = !params.no_extra_bufts;
     mparams.no_host = params.no_host;
+    mparams.mtp = params.mtp;
 
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
diff --git a/common/common.h b/common/common.h
index 6c2f1dc686..40d7689872 100644
--- a/common/common.h
+++ b/common/common.h
@@ -430,7 +430,7 @@ struct common_params {
     bool no_op_offload = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
     bool no_host = false; // bypass host buffer allowing extra buffers to be used
-    bool mtp = false; // use mtp is supported
+    bool mtp = false; // enable MTP if supported by the model
 
     bool single_turn = false; // single turn chat conversation
 
diff --git a/include/llama.h b/include/llama.h
index ce44c2d345..0428a9085b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -326,6 +326,7 @@ extern "C" {
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
         bool no_host;         // bypass host buffer allowing extra buffers to be used
         bool no_alloc;        // only load metadata and simulate memory allocations
+        bool mtp;             // enable MTP if supported by the model
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 36378440e6..cef7a0bd96 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1720,8 +1720,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
                 // NextN/MTP parameters
                 ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
-                hparams.n_layer_kv_from_start = hparams.n_layer;
-
+                if (params.mtp) {
+                    // include the MTP (NextN) layers in the KV cache when MTP is enabled
+                    hparams.n_layer_kv_from_start = hparams.n_layer;
+                }
+                else {
+                    // otherwise exclude them to save memory
+                    hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+                }
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
                     case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
@@ -5050,9 +5056,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    }
 
                    // Load ALL tensors including NextN layer to satisfy total tensor count
-                   // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
+                   // but skip loading data for the NextN layers when MTP is disabled to save VRAM
                    for (int i = 0; i < n_layer; ++i) {
                        int flags = 0;
+                       // skip loading the MTP (NextN) tensors when the feature is disabled
+                       if (!params.mtp) {
+                           if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                               flags |= TENSOR_SKIP;
+                           }
+                       }
 
                        auto & layer = layers[i];
 
@@ -7673,6 +7685,7 @@ llama_model_params llama_model_default_params() {
         /*.use_extra_bufts =*/ true,
         /*.no_host =*/ false,
         /*.no_alloc =*/ false,
+        /*.mtp =*/ false,
     };
 
     return result;
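
To illustrate the bookkeeping above with concrete numbers, a sketch for
GLM-4.5-Air (47 layers in the model, of which 1 is the NextN/MTP layer;
the variable names mirror the hparams fields touched by this patch):

    // GLM-4.5-Air: hparams.n_layer = 47, hparams.nextn_predict_layers = 1
    uint32_t n_layer              = 47;
    uint32_t nextn_predict_layers = 1;
    bool     mtp                  = false; // default

    // with mtp == false, layer index 46 (the NextN layer) gets TENSOR_SKIP
    // and the KV cache is sized for the first 46 layers only
    uint32_t n_layer_kv_from_start = mtp ? n_layer : n_layer - nextn_predict_layers; // 46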