add mtp graph for glm dense
This commit is contained in:
parent
6ded9d7b77
commit
c256da1f9f
|
|
@ -168,6 +168,7 @@ enum common_speculative_type {
|
|||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NEXTN, // MTP model with NextN layer (deepseek, GLM)
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
|||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NEXTN,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
|
||||
|
|
@ -32,6 +33,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
|
|||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"nextn", COMMON_SPECULATIVE_TYPE_NEXTN},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ struct llama_context {
|
|||
float * get_embeddings();
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
|
||||
int32_t cpy_mtp_state(llama_context & ctx_mtp);
|
||||
|
||||
llama_token * get_sampled_tokens() const;
|
||||
|
|
|
|||
|
|
@ -8577,14 +8577,18 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
} break;
|
||||
case LLM_ARCH_GLM4:
|
||||
{
|
||||
llm = std::make_unique<llm_build_glm4>(*this, params);
|
||||
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
|
||||
llm = std::make_unique<llm_build_glm4<LLM_GRAPH_TYPE_DECODER_MTP>>(*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_glm4<LLM_GRAPH_TYPE_DECODER>>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
{
|
||||
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
|
||||
llm = std::make_unique<llm_build_glm4_moe<true>>(*this, params);
|
||||
llm = std::make_unique<llm_build_glm4_moe<LLM_GRAPH_TYPE_DECODER_MTP>>(*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_glm4_moe<false>>(*this, params);
|
||||
llm = std::make_unique<llm_build_glm4_moe<LLM_GRAPH_TYPE_DECODER>>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#include "models.h"
|
||||
|
||||
template <>
|
||||
llm_build_glm4_moe<false>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
llm_build_glm4_moe<LLM_GRAPH_TYPE_DECODER>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const bool use_mrope = hparams.use_mrope();
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ llm_build_glm4_moe<false>::llm_build_glm4_moe(const llama_model & model, const l
|
|||
|
||||
// MTP model
|
||||
template <>
|
||||
llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
llm_build_glm4_moe<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
|
@ -109,8 +109,8 @@ llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const ll
|
|||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
template <bool is_mtp>
|
||||
ggml_tensor * llm_build_glm4_moe<is_mtp>::build_layer(const llama_model & model,
|
||||
template <llm_graph_type graph_type>
|
||||
ggml_tensor * llm_build_glm4_moe<graph_type>::build_layer(const llama_model & model,
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
ggml_tensor * inpL,
|
||||
ggml_tensor * inp_pos,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
template <>
|
||||
llm_build_glm4<LLM_GRAPH_TYPE_DECODER>::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
const bool use_mrope = hparams.use_mrope();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
|
|
@ -16,7 +15,6 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
|
|||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
bool use_mrope = hparams.use_mrope();
|
||||
if (ubatch.embd && !use_mrope) {
|
||||
// unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
|
||||
GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
|
||||
|
|
@ -33,116 +31,12 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
|
|||
// Final layer tensors are loaded but not processed in forward pass
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// Pre-attention norm
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * Qcur = nullptr;
|
||||
ggml_tensor * Kcur = nullptr;
|
||||
ggml_tensor * Vcur = nullptr;
|
||||
|
||||
if (model.layers[il].wqkv == nullptr) {
|
||||
Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
}
|
||||
Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
}
|
||||
Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
}
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
if (model.layers[il].bqkv) {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1],
|
||||
0 * sizeof(float) * (n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
|
||||
cur->nb[1], 1 * sizeof(float) * (n_embd));
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
|
||||
cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
|
||||
}
|
||||
|
||||
if (use_mrope) {
|
||||
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
} else {
|
||||
// Normal RoPE
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
|
||||
rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
|
||||
rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
// Post-attention norm (new!)
|
||||
cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "post_attn_norm", il);
|
||||
|
||||
// Add the input (residual connection after post-attention norm)
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// FF
|
||||
{
|
||||
// Pre-MLP norm
|
||||
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MLP
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// Post-MLP norm
|
||||
cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "post_mlp_norm", il);
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
bool is_output_layer = (il == n_transformer_layers - 1);
|
||||
inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
|
||||
}
|
||||
// Final norm
|
||||
cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cur = inpL;
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
|
@ -155,3 +49,183 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
|
|||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// MTP model
|
||||
template <>
|
||||
llm_build_glm4<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
// for now, we only support one single NextN layer for simplicity
|
||||
GGML_ASSERT(hparams.nextn_predict_layers == 1);
|
||||
const int il = n_layer - hparams.nextn_predict_layers;
|
||||
auto & mtp_layer = model.layers[il];
|
||||
|
||||
ggml_tensor * inp_token_embd = build_inp_embd(mtp_layer.nextn.embed_tokens // can be nullptr on GLM-4.6
|
||||
? mtp_layer.nextn.embed_tokens : model.tok_embd);
|
||||
ggml_tensor * inp_state_embd = build_inp_cross_mtp();
|
||||
|
||||
// check number of input tokens
|
||||
GGML_ASSERT(inp_state_embd->ne[1] == inp_token_embd->ne[1]);
|
||||
|
||||
inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
|
||||
inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
|
||||
|
||||
inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0);
|
||||
cb(inpL, "inp_mtp", il);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
{
|
||||
// input for next layer
|
||||
bool is_output_layer = (il == n_layer - 1);
|
||||
inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
|
||||
}
|
||||
cur = inpL;
|
||||
cur = build_norm(cur, mtp_layer.nextn.shared_head_norm // can be nullptr on GLM-4.6
|
||||
? mtp_layer.nextn.shared_head_norm : model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(mtp_layer.nextn.shared_head_head // can be nullptr on GLM-4.6
|
||||
? mtp_layer.nextn.shared_head_head : model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
template <llm_graph_type graph_type>
|
||||
ggml_tensor * llm_build_glm4<graph_type>::build_layer(const llama_model & model,
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
ggml_tensor * inpL,
|
||||
ggml_tensor * inp_pos,
|
||||
ggml_tensor * inp_out_ids,
|
||||
int sections[4],
|
||||
bool is_output_layer,
|
||||
int il) {
|
||||
bool use_mrope = hparams.use_mrope();
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// Pre-attention norm
|
||||
ggml_tensor * cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * Qcur = nullptr;
|
||||
ggml_tensor * Kcur = nullptr;
|
||||
ggml_tensor * Vcur = nullptr;
|
||||
|
||||
if (model.layers[il].wqkv == nullptr) {
|
||||
Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
}
|
||||
Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
}
|
||||
Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
}
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
if (model.layers[il].bqkv) {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1],
|
||||
0 * sizeof(float) * (n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
|
||||
cur->nb[1], 1 * sizeof(float) * (n_embd));
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
|
||||
cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
|
||||
}
|
||||
|
||||
if (use_mrope) {
|
||||
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
} else {
|
||||
// Normal RoPE
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
|
||||
rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
|
||||
rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
if (is_output_layer && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
// Post-attention norm (new!)
|
||||
cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "post_attn_norm", il);
|
||||
|
||||
// Add the input (residual connection after post-attention norm)
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// FF
|
||||
{
|
||||
// Pre-MLP norm
|
||||
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MLP
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// Post-MLP norm
|
||||
cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "post_mlp_norm", il);
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -250,11 +250,20 @@ struct llm_build_gemma : public llm_graph_context {
|
|||
llm_build_gemma(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
template <llm_graph_type graph_type>
|
||||
struct llm_build_glm4 : public llm_graph_context {
|
||||
llm_build_glm4(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * build_layer(const llama_model & model,
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
ggml_tensor * inpL,
|
||||
ggml_tensor * inp_pos,
|
||||
ggml_tensor * inp_out_ids,
|
||||
int sections[4],
|
||||
bool is_output_layer,
|
||||
int il);
|
||||
};
|
||||
|
||||
template <bool is_mtp>
|
||||
template <llm_graph_type graph_type>
|
||||
struct llm_build_glm4_moe : public llm_graph_context {
|
||||
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * build_layer(const llama_model & model,
|
||||
|
|
|
|||
Loading…
Reference in New Issue