PERF: use Slice+Concat instead of ScatterNDUpdate when writing cache_v
This commit is contained in:
parent
8ac5c225aa
commit
d7cc802292
|
|
@ -1,13 +1,17 @@
|
|||
#include <climits>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <openvino/core/node.hpp>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/core/node_vector.hpp>
|
||||
#include <openvino/op/add.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert_like.hpp>
|
||||
#include <openvino/op/range.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/scatter_nd_update.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <vector>
|
||||
|
|
@ -64,42 +68,40 @@ OutputVector translate_cpy(const NodeContext& context) {
|
|||
} else {
|
||||
// Write V to cache_v
|
||||
int64_t total_head_size = src0_shape[1];
|
||||
auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
|
||||
|
||||
auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
|
||||
src0,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{total_head_size, -1}),
|
||||
false);
|
||||
auto transposed_src0 =
|
||||
std::make_shared<ov::op::v1::Transpose>(reshaped_src0,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0}));
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
|
||||
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
|
||||
past_token_len = std::make_shared<ov::op::v0::Unsqueeze>(past_token_len, zero);
|
||||
auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
|
||||
|
||||
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
|
||||
src1,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{total_head_size, -1}),
|
||||
false);
|
||||
auto transposed_src1 =
|
||||
std::make_shared<ov::op::v1::Transpose>(reshaped_src1,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0}));
|
||||
|
||||
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
|
||||
token_len = std::make_shared<ov::op::v1::Reshape>(token_len,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {0}, {}),
|
||||
false);
|
||||
auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
|
||||
std::shared_ptr<ov::Node> indices =
|
||||
std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len, one, ov::element::i64);
|
||||
indices = std::make_shared<ov::op::v0::Unsqueeze>(
|
||||
indices,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{1}));
|
||||
|
||||
auto res = std::make_shared<ov::op::v3::ScatterNDUpdate>(transposed_src1, indices, transposed_src0);
|
||||
auto transposed_res =
|
||||
std::make_shared<ov::op::v1::Transpose>(res, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0}));
|
||||
auto reshaped_res = std::make_shared<ov::op::v1::Reshape>(
|
||||
transposed_res,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
|
||||
false);
|
||||
return {reshaped_res};
|
||||
|
||||
auto src1_left = std::make_shared<ov::op::v8::Slice>(
|
||||
reshaped_src1,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 0, 0}),
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, total_head_size_node, past_token_len}, 0),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
|
||||
|
||||
auto src1_right = std::make_shared<ov::op::v8::Slice>(
|
||||
reshaped_src1,
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{zero, zero, total_token_len}, 0),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, INT_MAX}),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
|
||||
|
||||
auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
|
||||
src0,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
|
||||
false);
|
||||
|
||||
auto res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2);
|
||||
|
||||
return {res};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue