diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b0a6ea323f..67f6712744 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1963,15 +1963,11 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main); ggml_build_forward_expand(gf, output_states); - // copy extra states which won't be changed further (between n_seqs and n_rs) - // Skip if there are no extra states to copy (n_rs == n_seqs) - if (arch != LLM_ARCH_KIMI_LINEAR || n_rs > (u_int32_t) n_seqs) { // arch check for backward compat - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); - } + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); return output_states; }