PERF: use Slice+Concat in writing cache_v

This commit is contained in:
Yu, Zijun 2025-05-16 10:14:05 +08:00 committed by Mustafa Cavus
parent 8ac5c225aa
commit d7cc802292
1 changed file with 32 additions and 30 deletions

View File

@ -1,13 +1,17 @@
#include <climits>
#include <cstdint>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/core/node_vector.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert_like.hpp>
#include <openvino/op/range.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scatter_nd_update.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <vector>
@ -64,42 +68,40 @@ OutputVector translate_cpy(const NodeContext& context) {
} else {
// Write V to cache_v
int64_t total_head_size = src0_shape[1];
auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
src0,
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{total_head_size, -1}),
false);
auto transposed_src0 =
std::make_shared<ov::op::v1::Transpose>(reshaped_src0,
ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0}));
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
past_token_len = std::make_shared<ov::op::v0::Unsqueeze>(past_token_len, zero);
auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
src1,
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{total_head_size, -1}),
false);
auto transposed_src1 =
std::make_shared<ov::op::v1::Transpose>(reshaped_src1,
ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0}));
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
token_len = std::make_shared<ov::op::v1::Reshape>(token_len,
ov::op::v0::Constant::create(ov::element::i64, {0}, {}),
false);
auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
std::shared_ptr<ov::Node> indices =
std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len, one, ov::element::i64);
indices = std::make_shared<ov::op::v0::Unsqueeze>(
indices,
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{1}));
auto res = std::make_shared<ov::op::v3::ScatterNDUpdate>(transposed_src1, indices, transposed_src0);
auto transposed_res =
std::make_shared<ov::op::v1::Transpose>(res, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0}));
auto reshaped_res = std::make_shared<ov::op::v1::Reshape>(
transposed_res,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
false);
return {reshaped_res};
auto src1_left = std::make_shared<ov::op::v8::Slice>(
reshaped_src1,
ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 0, 0}),
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, total_head_size_node, past_token_len}, 0),
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
auto src1_right = std::make_shared<ov::op::v8::Slice>(
reshaped_src1,
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{zero, zero, total_token_len}, 0),
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, INT_MAX}),
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
src0,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
false);
auto res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2);
return {res};
}
}