models : fix assert in mamba2 graph (#20270)
This commit is contained in:
parent
107d599952
commit
43e1cbd6c1
|
|
@ -155,7 +155,6 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
|||
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t d_conv = hparams.ssm_d_conv;
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
const int64_t d_state = hparams.ssm_d_state;
|
||||
|
|
@ -170,7 +169,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
|||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % (n_group*n_embd) == 0);
|
||||
GGML_ASSERT(d_inner % (n_group*d_state) == 0);
|
||||
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
|
|
|||
Loading…
Reference in New Issue