diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1d0d7197e1..8edf7d749b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -32,7 +32,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
     res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
-    res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
 
     return res;
 }
@@ -62,7 +62,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
 bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
-    res &= pos->ne[0] == params.ubatch.n_tokens;
+    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
 
     return res;
 }
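
The sketch below is a minimal, standalone illustration (not the llama.cpp API) of the shape checks the patch corrects. It assumes the usual ggml layout where ne[0] is the innermost dimension, so an input-embedding tensor is shaped [n_embd, n_tokens] (tokens live in ne[1]), and that the position tensor is a flat vector holding n_pos_per_embd position values per token (more than one for multi-positional RoPE variants). The toy_tensor type, shapes_reusable helper, and the example value n_pos_per_embd = 4 are illustrative assumptions.

#include <array>
#include <cstdint>
#include <cstdio>

struct toy_tensor {
    std::array<int64_t, 2> ne; // ne[0] = innermost dimension, ne[1] = next dimension
};

// hypothetical helper mirroring the corrected reuse checks
static bool shapes_reusable(const toy_tensor & embd, const toy_tensor & pos,
                            int64_t n_tokens, int64_t n_pos_per_embd) {
    bool res = true;
    res &= embd.ne[1] == n_tokens;                  // token count is ne[1] for [n_embd, n_tokens]
    res &= pos.ne[0]  == n_tokens * n_pos_per_embd; // pos is flat: n_pos_per_embd entries per token
    return res;
}

int main() {
    const int64_t n_embd = 4096, n_tokens = 8, n_pos_per_embd = 4; // example values only
    toy_tensor embd{{n_embd, n_tokens}};
    toy_tensor pos{{n_tokens * n_pos_per_embd, 1}};
    std::printf("reusable: %s\n",
                shapes_reusable(embd, pos, n_tokens, n_pos_per_embd) ? "yes" : "no");
    return 0;
}

With the old checks, an embedding batch would be compared on ne[0] (the embedding width) and a multi-positional pos tensor on the raw token count, so a graph could be wrongly reused or wrongly rejected when only the batch size changed; the corrected comparisons match the dimensions that actually vary with the ubatch.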