model : use ggml_swiglu_split for Mamba

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Francis Couture-Harpin 2025-07-08 15:45:20 -04:00
parent 2f39cd7bb7
commit f7c7a926f0
1 changed files with 2 additions and 2 deletions

View File

@ -10057,7 +10057,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
// TODO: skip computing output earlier for unused tokens // TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d)); y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(layer.ssm_out, y); cur = build_lora_mm(layer.ssm_out, y);
@ -10181,7 +10181,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
// TODO: skip computing output earlier for unused tokens // TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// grouped RMS norm // grouped RMS norm
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);