glm-moe: allow skipping MTP tensor loading to save VRAM
Adds a new `mtp` (multi-token prediction) boolean to `llama_model_params`. When set to false (the default):

1. The loader skips loading the MTP-specific tensors (NextN layers) using `TENSOR_SKIP`.
2. The KV cache size calculation excludes the MTP layer (`n_layer_kv_from_start`).

This reduces VRAM usage and load time for users running GLM-4.5/4.6 in standard generation mode.
parent d9576dd037
commit a3e29da02a
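For API consumers, a minimal sketch of how the flag is meant to be used (the `mtp` field is the one added by this commit; the model path is a placeholder):

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.mtp = false; // default: skip the NextN/MTP tensors and shrink the KV cache

    // Any GLM-4.5/4.6 GGUF applies; the path is illustrative.
    llama_model * model = llama_model_load_from_file("glm-4.5-air.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    // ... create a context and generate as usual; the NextN layer is never touched ...

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```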
@@ -1351,6 +1351,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.check_tensors = params.check_tensors;
     mparams.use_extra_bufts = !params.no_extra_bufts;
     mparams.no_host = params.no_host;
+    mparams.mtp = params.mtp;
 
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
@@ -430,7 +430,7 @@ struct common_params {
     bool no_op_offload = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
     bool no_host = false; // bypass host buffer allowing extra buffers to be used
-    bool mtp = false; // use mtp is supported
+    bool mtp = false; // enable MTP if supported by the model
 
     bool single_turn = false; // single turn chat conversation
 
@@ -326,6 +326,7 @@ extern "C" {
        bool use_extra_bufts; // use extra buffer types (used for weight repacking)
        bool no_host; // bypass host buffer allowing extra buffers to be used
        bool no_alloc; // only load metadata and simulate memory allocations
+       bool mtp; // use MTP if supported by the model
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
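Since `llama_model_params` crosses the C ABI by value, a field appended like this stays backward-compatible only if callers start from `llama_model_default_params()` instead of hand-initializing the struct. A small sketch of that pattern (the helper name is hypothetical):

```cpp
#include "llama.h"

// Hypothetical helper: build params from the library defaults so fields added
// in newer versions (such as mtp) are always initialized to something sane.
static llama_model_params make_model_params(bool enable_mtp) {
    llama_model_params p = llama_model_default_params();
    p.mtp = enable_mtp; // opt in to loading the NextN/MTP tensors
    return p;
}
```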
@@ -337,4 +337,4 @@ private:
     mutable int32_t n_eval = 0; // number of eval calls
 
     mutable int32_t n_reused = 0; // number of times the previous graph was reused
-};
+};
@@ -860,4 +860,4 @@ struct llm_graph_context {
 };
 
 // TODO: better name
-int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
@@ -1720,8 +1720,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // NextN/MTP parameters
                 ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
 
-                hparams.n_layer_kv_from_start = hparams.n_layer;
-
+                if (params.mtp) {
+                    // Include MTP layers in KV cache if MTP is enabled
+                    hparams.n_layer_kv_from_start = hparams.n_layer;
+                }
+                else {
+                    // Otherwise exclude to save memory
+                    hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+                }
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
                     case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
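To make the KV-cache change concrete, here is the same arithmetic in isolation, using the GLM-4.5-Air numbers from the `switch` above (47 layers, one of them a NextN layer); a standalone sketch, not code from the tree:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
    const uint32_t n_layer              = 47; // GLM-4.5-Air: 46 layers + 1 NextN layer
    const uint32_t nextn_predict_layers = 1;

    for (bool mtp : { false, true }) {
        // Mirrors the hunk above: drop MTP layers from the KV cache when disabled.
        const uint32_t n_layer_kv_from_start = mtp ? n_layer : n_layer - nextn_predict_layers;
        printf("mtp=%d -> KV cache covers %u of %u layers\n", mtp, n_layer_kv_from_start, n_layer);
    }
    return 0;
}
```

With `mtp = false` the cache covers 46 of 47 layers, so the saving grows with context length on top of the skipped weights.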
@@ -5050,9 +5056,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
 
                 // Load ALL tensors including NextN layer to satisfy total tensor count
-                // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
+                // but skip loading data for NextN layers if MTP is disabled to save VRAM
                 for (int i = 0; i < n_layer; ++i) {
+                    int flags = 0;
+                    // Skip loading MTP layers if the feature is disabled
+                    if (!params.mtp) {
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            flags |= TENSOR_SKIP;
+                        }
+                    }
+
                     auto & layer = layers[i];
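The skip condition from the loop above, factored into a standalone predicate for readability (a hypothetical helper for illustration; the commit keeps the check inline):

```cpp
#include <cstdint>

// True when layer i is one of the trailing NextN/MTP layers and MTP is
// disabled, i.e. the loader should mark its tensors with TENSOR_SKIP.
static bool should_skip_mtp_tensors(uint32_t i, uint32_t n_layer,
                                    uint32_t nextn_predict_layers, bool mtp) {
    return !mtp
        && nextn_predict_layers > 0
        && i >= n_layer - nextn_predict_layers; // NextN layers sit at the end
}
```

A simple index comparison suffices because the NextN layers always occupy the tail of the layer stack, as the `// 46 layers + 1 NextN layer` comments in `load_hparams` note.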
@@ -7673,6 +7685,7 @@ llama_model_params llama_model_default_params() {
         /*.use_extra_bufts =*/ true,
         /*.no_host =*/ false,
         /*.no_alloc =*/ false,
+        /*.mtp =*/ false,
     };
 
     return result;