graph : fix graph reuse logic when `n_pos_per_embd > 1` (#18566)

This commit is contained in:
Georgi Gerganov 2026-01-03 23:59:06 +02:00 committed by GitHub
parent e57f52334b
commit c69c7ebc90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 2 deletions

View File

@ -32,7 +32,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
return res;
}
@ -62,7 +62,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= pos->ne[0] == params.ubatch.n_tokens;
res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
return res;
}