Fix stateful shapes
This commit is contained in:
parent
d398214e14
commit
26328fe118
|
|
@@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
|
|||
src1 = context.get_input(1);
|
||||
} else {
|
||||
auto combined = context.get_input(0);
|
||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
|
||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
|
||||
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
||||
src0 = split->output(0);
|
||||
src1 = split->output(1);
|
||||
|
|
|
|||
|
|
@@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
|
|||
src1 = context.get_input(1);
|
||||
} else {
|
||||
auto combined = context.get_input(0);
|
||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
|
||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
|
||||
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
||||
src0 = split->output(0);
|
||||
src1 = split->output(1);
|
||||
|
|
|
|||
|
|
@@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
constexpr int ROPE_TYPE_NORM = 0;
|
||||
|
||||
if (mode == ROPE_TYPE_NORM) {
|
||||
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
|
||||
Output<Node> even_slice;
|
||||
Output<Node> odd_slice;
|
||||
int32_t unsqueeze_dim = 4;
|
||||
if (context.is_stateful()) {
|
||||
unsqueeze_dim = 3;
|
||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
|
||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
|
||||
} else {
|
||||
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
|
||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
|
||||
}
|
||||
int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4;
|
||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);
|
||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);
|
||||
|
||||
Output<Node> first_half =
|
||||
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
|
||||
|
|
@@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
|
||||
} else if (mode == ROPE_TYPE_NEOX) {
|
||||
auto data_split = std::make_shared<ov::op::v1::Split>(
|
||||
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
|
||||
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);
|
||||
Output<Node> slice_data_node_0 = data_split->outputs()[0];
|
||||
Output<Node> slice_data_node_1 = data_split->outputs()[1];
|
||||
|
||||
|
|
@@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
|
||||
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
|
||||
|
||||
int32_t concat_dim = 3;
|
||||
if (context.is_stateful()) {
|
||||
concat_dim = 2;
|
||||
}
|
||||
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
|
||||
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);
|
||||
}
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
|
|
|
|||
|
|
@@ -216,7 +216,7 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
|
|||
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
|
||||
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
|
||||
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});
|
||||
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
|
||||
return sliced;
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -497,6 +497,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
|
|||
|
||||
ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
|
||||
const std::string & param_name) {
|
||||
// NPU decoding stage
|
||||
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
|
||||
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
|
||||
|
||||
|
|
@@ -540,6 +541,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml
|
|||
ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
|
||||
const std::string & param_name,
|
||||
int chunk_index) {
|
||||
// NPU prompt processing stage
|
||||
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
|
||||
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue