diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6e9dd53223..34643226e5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9194,11 +9194,11 @@ struct llm_build_mamba : public llm_graph_context { // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state( gf, conv_states_all, state_copy, - hparams.n_embd_k_s(), n_seqs); + hparams.n_embd_k_s(il), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); ggml_tensor * ssm = build_recurrent_state( gf, ssm_states_all, state_copy, - hparams.n_embd_v_s(), n_seqs); + hparams.n_embd_v_s(il), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}