models : fix assert in mamba2 (cont) (#20335)

* models : fix assert in mamba2 (cont)

* cont : add n_group mod

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Georgi Gerganov 2026-03-10 15:00:08 +02:00 committed by GitHub
parent a7b3dee7a5
commit 1274fbee9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 2 deletions

View File

@ -168,8 +168,9 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % (n_group*d_state) == 0);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % d_state == 0);
GGML_ASSERT(d_inner % n_group == 0);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);