// llama.cpp/src/models/glm4.cpp
//
// (listing header from a web view: 234 lines, 9.1 KiB, C++)

#include "models.h"
template <>
llm_build_glm4<LLM_GRAPH_TYPE_DECODER>::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    // Main decoder graph: token embeddings -> transformer stack -> final RMS norm -> LM head.
    // Trailing NextN (MTP) layers are loaded with the model but are NOT evaluated here;
    // the LLM_GRAPH_TYPE_DECODER_MTP specialization handles them.
    const int64_t head_dim  = hparams.n_embd_head_v;
    const bool    has_mrope = hparams.use_mrope();

    GGML_ASSERT(head_dim == hparams.n_embd_head_k);

    // local copy of the 4 RoPE section sizes, handed to every layer
    int rope_sections[4];
    for (int i = 0; i < 4; ++i) {
        rope_sections[i] = hparams.rope_sections[i];
    }

    ggml_tensor * hidden = build_inp_embd(model.tok_embd);

    // Embedding input is treated as multimodal; only GGUFs with M-RoPE enabled accept it.
    // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
    if (!has_mrope && ubatch.embd) {
        GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
    }

    ggml_tensor * inp_pos     = build_inp_pos();      // token positions
    auto        * inp_attn    = build_attn_inp_kv();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    // skip the NextN layers at the end of model.layers
    const int n_main_layers = n_layer - hparams.nextn_predict_layers;
    for (int il = 0; il < n_main_layers; ++il) {
        const bool last = (il + 1 == n_main_layers);
        hidden = build_layer(model, inp_attn, hidden, inp_pos, inp_out_ids, rope_sections, last, il);
    }

    ggml_tensor * cur = build_norm(hidden, model.output_norm, nullptr, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // output projection
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
// MTP model: evaluates only the single NextN layer stored after the regular layers.
// Its input is eh_proj(concat(enorm(token embeddings), hnorm(incoming hidden state))).
template <>
llm_build_glm4<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t head_dim = hparams.n_embd_head_v;
    GGML_ASSERT(head_dim == hparams.n_embd_head_k);

    int rope_sections[4];
    for (int i = 0; i < 4; ++i) {
        rope_sections[i] = hparams.rope_sections[i];
    }

    // for now, we only support one single NextN layer for simplicity
    GGML_ASSERT(hparams.nextn_predict_layers == 1);

    const int    il    = n_layer - hparams.nextn_predict_layers;
    const auto & nextn = model.layers[il].nextn;

    // These MTP-specific tensors can be nullptr on GLM-4.6; reuse the main model's tensors then.
    ggml_tensor * tok_embd_w = nextn.embed_tokens     ? nextn.embed_tokens     : model.tok_embd;
    ggml_tensor * out_norm_w = nextn.shared_head_norm ? nextn.shared_head_norm : model.output_norm;
    ggml_tensor * lm_head_w  = nextn.shared_head_head ? nextn.shared_head_head : model.output;

    ggml_tensor * tok_in   = build_inp_embd(tok_embd_w);
    ggml_tensor * state_in = build_inp_cross_mtp();

    // token embeddings and hidden states must cover the same number of tokens
    GGML_ASSERT(state_in->ne[1] == tok_in->ne[1]);

    tok_in   = build_norm(tok_in,   nextn.enorm, nullptr, LLM_NORM_RMS, il);
    state_in = build_norm(state_in, nextn.hnorm, nullptr, LLM_NORM_RMS, il);

    ggml_tensor * hidden = ggml_concat(ctx0, tok_in, state_in, 0);
    cb(hidden, "inp_mtp", il);

    hidden = build_lora_mm(nextn.eh_proj, hidden);
    cb(hidden, "inp_mtp_projected", il);

    ggml_tensor * inp_pos     = build_inp_pos();      // token positions
    auto        * inp_attn    = build_attn_inp_kv();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    // TODO: with a single NextN layer it is always the output layer; revisit if more are supported
    hidden = build_layer(model, inp_attn, hidden, inp_pos, inp_out_ids, rope_sections, true, il);

    ggml_tensor * cur = build_norm(hidden, out_norm_w, nullptr, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(lm_head_w, cur);
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
// Builds one GLM4 transformer block on top of `inpL` and returns its output tensor.
// Both sub-blocks are "sandwich" normalized:
//   h   = inpL + post_attn_norm(attn(attn_norm(inpL)))
//   out = h    + post_ffn_norm(ffn(ffn_norm(h)))
// When `is_output_layer` is set and `inp_out_ids` is provided, only the requested
// output rows are kept after attention.
template <llm_graph_type graph_type>
ggml_tensor * llm_build_glm4<graph_type>::build_layer(const llama_model & model,
                                                      llm_graph_input_attn_kv * inp_attn,
                                                      ggml_tensor * inpL,
                                                      ggml_tensor * inp_pos,
                                                      ggml_tensor * inp_out_ids,
                                                      int sections[4],
                                                      bool is_output_layer,
                                                      int il) {
    const auto &  layer    = model.layers[il];
    const bool    mrope    = hparams.use_mrope();
    const int64_t head_dim = hparams.n_embd_head_v;

    // adds an optional bias; returns the input untouched when the bias tensor is absent
    const auto with_bias = [this](ggml_tensor * t, ggml_tensor * b) {
        return b ? ggml_add(ctx0, t, b) : t;
    };

    // applies M-RoPE (with `sections`) or standard RoPE, depending on the model
    const auto rope = [&](ggml_tensor * t) {
        if (mrope) {
            return ggml_rope_multi(ctx0, t, inp_pos, nullptr,
                                   n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
                                   ext_factor, attn_factor, beta_fast, beta_slow);
        }
        return ggml_rope_ext(ctx0, t, inp_pos, nullptr, n_rot,
                             rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);
    };

    // residual carried around the attention sub-block
    ggml_tensor * residual = inpL;

    ggml_tensor * x = build_norm(inpL, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
    cb(x, "attn_norm", il);

    ggml_tensor * Qcur = nullptr;
    ggml_tensor * Kcur = nullptr;
    ggml_tensor * Vcur = nullptr;

    if (layer.wqkv != nullptr) {
        // fused projection; the row is laid out as [Q | K | V]
        ggml_tensor * qkv = build_lora_mm(layer.wqkv, x);
        cb(qkv, "wqkv", il);
        if (layer.bqkv) {
            qkv = ggml_add(ctx0, qkv, layer.bqkv);
            cb(qkv, "bqkv", il);
        }

        const int64_t kv_width    = hparams.n_embd_v_gqa();
        const size_t  head_stride = head_dim * sizeof(float);
        const size_t  row_stride  = qkv->nb[1];
        const size_t  k_offset    = sizeof(float) * n_embd;
        const size_t  v_offset    = sizeof(float) * (n_embd + kv_width);

        Qcur = ggml_view_3d(ctx0, qkv, head_dim, n_head,    n_tokens, head_stride, row_stride, 0);
        Kcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_tokens, head_stride, row_stride, k_offset);
        Vcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_tokens, head_stride, row_stride, v_offset);
    } else {
        // separate Q/K/V projections, each with an optional bias
        Qcur = with_bias(build_lora_mm(layer.wq, x), layer.bq);
        Kcur = with_bias(build_lora_mm(layer.wk, x), layer.bk);
        Vcur = with_bias(build_lora_mm(layer.wv, x), layer.bv);

        Qcur = ggml_reshape_3d(ctx0, Qcur, head_dim, n_head,    n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, head_dim, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, head_dim, n_head_kv, n_tokens);
    }

    Qcur = rope(Qcur);
    Kcur = rope(Kcur);

    cb(Qcur, "Qcur", il);
    cb(Kcur, "Kcur", il);
    cb(Vcur, "Vcur", il);

    ggml_tensor * attn_out = build_attn(inp_attn,
                                        layer.wo, nullptr,
                                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                                        1.0f / sqrtf(float(head_dim)), il);

    if (is_output_layer && inp_out_ids) {
        // keep only the rows whose outputs were requested
        attn_out = ggml_get_rows(ctx0, attn_out, inp_out_ids);
        residual = ggml_get_rows(ctx0, residual, inp_out_ids);
    }

    // the residual is added AFTER the post-attention norm
    attn_out = build_norm(attn_out, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
    cb(attn_out, "post_attn_norm", il);

    ggml_tensor * ffn_inp = ggml_add(ctx0, attn_out, residual);
    cb(ffn_inp, "ffn_inp", il);

    ggml_tensor * y = build_norm(ffn_inp, layer.ffn_norm, nullptr, LLM_NORM_RMS, il);
    cb(y, "ffn_norm", il);

    // No separate gate tensor is passed to the SwiGLU FFN.
    // NOTE(review): presumably ffn_up holds the fused gate/up weights — confirm.
    y = build_ffn(y,
                  layer.ffn_up,   nullptr, nullptr,
                  nullptr,        nullptr, nullptr,
                  layer.ffn_down, nullptr, nullptr,
                  nullptr, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
    cb(y, "ffn_out", il);

    y = build_norm(y, layer.ffn_post_norm, nullptr, LLM_NORM_RMS, il);
    cb(y, "post_mlp_norm", il);

    ggml_tensor * out = ggml_add(ctx0, y, ffn_inp);
    out = build_cvec(out, il);
    cb(out, "l_out", il);
    return out;
}