graph : fix graph reuse logic when `n_pos_per_embd > 1` (#18566)
This commit is contained in:
parent
e57f52334b
commit
c69c7ebc90
|
|
@ -32,7 +32,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
|||
bool res = true;
|
||||
|
||||
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -62,7 +62,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|||
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= pos->ne[0] == params.ubatch.n_tokens;
|
||||
res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue