Fix llama-cli (need to run with --no-warmup)
This commit is contained in:
parent
05d7abae8c
commit
a9371ea646
|
|
@ -42,15 +42,15 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
|
|||
mask_sliced = context.get_input(mask_name);
|
||||
} else {
|
||||
auto token_len = get_dimensions(q, {2});
|
||||
auto kv_len = get_dimensions(k.get_node_shared_ptr(), {2});
|
||||
|
||||
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
|
||||
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
|
||||
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
|
||||
auto inp_pos = context.get_input("inp_pos");
|
||||
auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
|
||||
auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
|
||||
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_inp_pos}, 0);
|
||||
|
||||
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
|
||||
mask_sliced =
|
||||
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
|
||||
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
|||
// Create common patterns
|
||||
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||
add_token_len(tensor_map);
|
||||
add_sliced_mask(tensor_map, ggml_model_decoder);
|
||||
// add_sliced_mask(tensor_map, ggml_model_decoder);
|
||||
add_rope_sin_cos(tensor_map, ggml_model_decoder);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue