KV cache fusion support
This commit is contained in:
parent
973a80fd02
commit
c112bc4e73
|
|
@ -316,9 +316,13 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
|
|||
input_shape = ov::PartialShape{1, -1, -1};
|
||||
}
|
||||
} else if (name.find("cache_") == 0) {
|
||||
int layer = extract_layer_from_name(name);
|
||||
bool is_swa = is_swa_layer(layer);
|
||||
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
|
||||
if (m_is_static) {
|
||||
int layer = extract_layer_from_name(name);
|
||||
bool is_swa = is_swa_layer(layer);
|
||||
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
|
||||
} else {
|
||||
input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
|
||||
}
|
||||
} else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
|
||||
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
|
||||
} else if (src->op == GGML_OP_VIEW) {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/scaled_dot_product_attention.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
|
|
@ -32,7 +33,7 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
|
|||
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
|
||||
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
|
||||
|
||||
ov::Output<ov::Node> mask_sliced;
|
||||
ov::Output<ov::Node> mask_sliced, res;
|
||||
std::string mask_name = "KQ_mask_sliced";
|
||||
if (context.get_input_names()[3].find("swa") != std::string::npos) {
|
||||
mask_name = "KQ_mask_swa_sliced";
|
||||
|
|
@ -40,33 +41,55 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
|
|||
if (context.has_input(mask_name)) {
|
||||
mask_sliced = context.get_input(mask_name);
|
||||
} else {
|
||||
auto token_len = get_dimensions(q, {1});
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
|
||||
auto token_len = get_dimensions(q, {2});
|
||||
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
|
||||
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
|
||||
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
|
||||
auto leaf_8 = context.get_input("leaf_8");
|
||||
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
|
||||
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
|
||||
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
|
||||
mask_sliced =
|
||||
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
|
||||
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
|
||||
}
|
||||
|
||||
if (mask_sliced.get_element_type() != ov::element::f16) {
|
||||
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
|
||||
}
|
||||
|
||||
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
|
||||
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) {
|
||||
int64_t factor = q_batch / kv_batch;
|
||||
if (factor > 1) {
|
||||
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
|
||||
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
|
||||
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
|
||||
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
|
||||
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
|
||||
if (is_static) {
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
|
||||
|
||||
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
|
||||
kv_broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
|
||||
new_kv_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
|
||||
} else {
|
||||
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
|
||||
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
|
||||
|
||||
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
|
||||
kv_broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
|
||||
new_kv_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
|
||||
}
|
||||
|
||||
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
|
||||
auto kv_broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
|
||||
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
|
||||
|
||||
auto new_kv_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
|
||||
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
|
||||
}
|
||||
return kv;
|
||||
|
|
@ -74,13 +97,18 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
|
|||
|
||||
auto q_shape = context.get_input_shape(0).to_shape();
|
||||
auto k_shape = context.get_input_shape(1).to_shape();
|
||||
k = tile_kv(q_shape[0], k_shape[0], k);
|
||||
v = tile_kv(q_shape[0], k_shape[0], v);
|
||||
k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
|
||||
v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());
|
||||
|
||||
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
|
||||
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
|
||||
auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
if (context.is_static()) {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
} else {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
|
||||
}
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -59,13 +59,23 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
|
||||
auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
|
||||
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
|
||||
|
||||
Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
|
||||
Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
|
||||
auto broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
|
||||
|
||||
ov::Output<Node> broadcast_shape;
|
||||
ov::Output<Node> Z_unsqueezed;
|
||||
if (context.is_static()) {
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
|
||||
broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
|
||||
} else {
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
|
||||
Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
|
||||
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, batch_small, factor_node, Z_last_two_dims}, 0);
|
||||
}
|
||||
auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
|
||||
|
||||
auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
|
||||
|
|
|
|||
|
|
@ -25,8 +25,13 @@ OutputVector translate_permute(const NodeContext& context) {
|
|||
ov::Output<Node> res;
|
||||
|
||||
if (op_case == 1) {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
if (context.is_static()) {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
} else {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
|
||||
}
|
||||
} else {
|
||||
auto src = context.get_input(0);
|
||||
Output<Node> attention_size;
|
||||
|
|
@ -38,20 +43,23 @@ OutputVector translate_permute(const NodeContext& context) {
|
|||
attention_size = context.get_input("attention_size_swa");
|
||||
}
|
||||
|
||||
auto src_shape_ = context.get_input_shape(0).to_shape();
|
||||
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
|
||||
|
||||
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
src,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
|
||||
false);
|
||||
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
|
||||
|
||||
res = std::make_shared<ov::op::v1::Transpose>(src_slice,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
if (context.is_static()) {
|
||||
auto src_shape_ = context.get_input_shape(0).to_shape();
|
||||
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
|
||||
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
src,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
|
||||
false);
|
||||
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
|
||||
res = std::make_shared<ov::op::v1::Transpose>(src_slice,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
} else {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(src,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
|
||||
}
|
||||
}
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,6 +84,9 @@ OutputVector translate_rope(const NodeContext& context) {
|
|||
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
|
||||
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
|
||||
res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v0::ShapeOf>(data_node), false);
|
||||
if (!(context.is_static())) {
|
||||
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
}
|
||||
} else if (mode == ROPE_TYPE_NEOX) {
|
||||
auto data_split = std::make_shared<ov::op::v1::Split>(
|
||||
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@
|
|||
#include <openvino/core/node.hpp>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/frontend/exception.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <openvino/op/scatter_update.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
|
|
@ -39,17 +41,29 @@ OutputVector translate_set_rows(const NodeContext& context) {
|
|||
auto dst = context.get_input(context.get_output_name());
|
||||
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
|
||||
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
dst,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
|
||||
false);
|
||||
auto indices_reshaped =
|
||||
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
|
||||
Output<Node> res;
|
||||
if (context.is_static()) {
|
||||
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
dst,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
|
||||
false);
|
||||
auto indices_reshaped =
|
||||
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
|
||||
|
||||
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
|
||||
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
|
||||
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
|
||||
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
|
||||
} else {
|
||||
// TODO: A better solution would be to reshape the data into 4D in the first place (for the stateful model)
|
||||
if (data.get_partial_shape().rank() + 1 == dst.get_partial_shape().rank()) {
|
||||
data = std::make_shared<ov::op::v0::Unsqueeze>(data, zero);
|
||||
}
|
||||
int concat_axis = 1;
|
||||
if (context.is_static())
|
||||
concat_axis = 0;
|
||||
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);
|
||||
}
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@
|
|||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/matmul.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/softmax.hpp>
|
||||
#include <vector>
|
||||
|
|
@ -57,9 +59,20 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
|||
} else {
|
||||
auto token_len = get_dimensions(input_node, {1});
|
||||
auto mask_node = context.get_input(1);
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
|
||||
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
|
||||
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
|
||||
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
|
||||
auto leaf_8 = context.get_input("leaf_8");
|
||||
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
|
||||
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
|
||||
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
|
||||
mask_node_sliced =
|
||||
std::make_shared<ov::op::v8::Slice>(mask_node, zero_2d, stop, one_2d, axes);
|
||||
if (!(context.is_static())) {
|
||||
mask_node_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_node_sliced, zero_1d);
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/cos.hpp>
|
||||
#include <openvino/op/divide.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <openvino/op/range.hpp>
|
||||
|
|
@ -87,9 +88,18 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
|||
if (is_static) {
|
||||
mask_sliced = mask;
|
||||
} else {
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
|
||||
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
|
||||
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
|
||||
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
|
||||
auto leaf_8 = tensor_map.at("leaf_8").get_node_shared_ptr();
|
||||
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
|
||||
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
|
||||
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
|
||||
mask_sliced =
|
||||
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
|
||||
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
|
||||
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
|
||||
mask_sliced->set_friendly_name(sliced_name);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue