Fix stateful shapes

This commit is contained in:
Yu, Zijun 2026-01-23 15:49:36 +08:00
parent d398214e14
commit 26328fe118
5 changed files with 11 additions and 19 deletions

View File

@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);

View File

@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);

View File

@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) {
constexpr int ROPE_TYPE_NORM = 0;
if (mode == ROPE_TYPE_NORM) {
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
Output<Node> even_slice;
Output<Node> odd_slice;
int32_t unsqueeze_dim = 4;
if (context.is_stateful()) {
unsqueeze_dim = 3;
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
} else {
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
}
int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4;
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);
Output<Node> first_half =
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) {
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
} else if (mode == ROPE_TYPE_NEOX) {
auto data_split = std::make_shared<ov::op::v1::Split>(
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);
Output<Node> slice_data_node_0 = data_split->outputs()[0];
Output<Node> slice_data_node_1 = data_split->outputs()[1];
@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) {
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
int32_t concat_dim = 3;
if (context.is_stateful()) {
concat_dim = 2;
}
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);
}
return rename_outputs_with_suffix({res}, context.get_name());

View File

@ -216,7 +216,7 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
return sliced;
}

View File

@ -497,6 +497,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
const std::string & param_name) {
// NPU decoding stage
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
@ -540,6 +541,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml
ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
const std::string & param_name,
int chunk_index) {
// NPU prompt processing stage
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);