Fix stateful shapes
This commit is contained in:
parent
d398214e14
commit
26328fe118
|
|
@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
|
||||||
src1 = context.get_input(1);
|
src1 = context.get_input(1);
|
||||||
} else {
|
} else {
|
||||||
auto combined = context.get_input(0);
|
auto combined = context.get_input(0);
|
||||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
|
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
|
||||||
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
||||||
src0 = split->output(0);
|
src0 = split->output(0);
|
||||||
src1 = split->output(1);
|
src1 = split->output(1);
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
|
||||||
src1 = context.get_input(1);
|
src1 = context.get_input(1);
|
||||||
} else {
|
} else {
|
||||||
auto combined = context.get_input(0);
|
auto combined = context.get_input(0);
|
||||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
|
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
|
||||||
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
||||||
src0 = split->output(0);
|
src0 = split->output(0);
|
||||||
src1 = split->output(1);
|
src1 = split->output(1);
|
||||||
|
|
|
||||||
|
|
@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) {
|
||||||
constexpr int ROPE_TYPE_NORM = 0;
|
constexpr int ROPE_TYPE_NORM = 0;
|
||||||
|
|
||||||
if (mode == ROPE_TYPE_NORM) {
|
if (mode == ROPE_TYPE_NORM) {
|
||||||
|
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
|
||||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||||
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||||
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
|
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
|
||||||
Output<Node> even_slice;
|
Output<Node> even_slice;
|
||||||
Output<Node> odd_slice;
|
Output<Node> odd_slice;
|
||||||
int32_t unsqueeze_dim = 4;
|
int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4;
|
||||||
if (context.is_stateful()) {
|
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);
|
||||||
unsqueeze_dim = 3;
|
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);
|
||||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
|
|
||||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
|
|
||||||
} else {
|
|
||||||
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
|
||||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
|
|
||||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
|
|
||||||
}
|
|
||||||
|
|
||||||
Output<Node> first_half =
|
Output<Node> first_half =
|
||||||
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
|
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
|
||||||
|
|
@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) {
|
||||||
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
|
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
|
||||||
} else if (mode == ROPE_TYPE_NEOX) {
|
} else if (mode == ROPE_TYPE_NEOX) {
|
||||||
auto data_split = std::make_shared<ov::op::v1::Split>(
|
auto data_split = std::make_shared<ov::op::v1::Split>(
|
||||||
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
|
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);
|
||||||
Output<Node> slice_data_node_0 = data_split->outputs()[0];
|
Output<Node> slice_data_node_0 = data_split->outputs()[0];
|
||||||
Output<Node> slice_data_node_1 = data_split->outputs()[1];
|
Output<Node> slice_data_node_1 = data_split->outputs()[1];
|
||||||
|
|
||||||
|
|
@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) {
|
||||||
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
|
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
|
||||||
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
|
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
|
||||||
|
|
||||||
int32_t concat_dim = 3;
|
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);
|
||||||
if (context.is_stateful()) {
|
|
||||||
concat_dim = 2;
|
|
||||||
}
|
|
||||||
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rename_outputs_with_suffix({res}, context.get_name());
|
return rename_outputs_with_suffix({res}, context.get_name());
|
||||||
|
|
|
||||||
|
|
@ -216,7 +216,7 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
|
||||||
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
|
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
|
||||||
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
|
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
|
||||||
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});
|
||||||
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
|
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
|
||||||
return sliced;
|
return sliced;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -497,6 +497,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
|
||||||
|
|
||||||
ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
|
ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
|
||||||
const std::string & param_name) {
|
const std::string & param_name) {
|
||||||
|
// NPU decoding stage
|
||||||
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
|
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
|
||||||
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
|
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
|
||||||
|
|
||||||
|
|
@ -540,6 +541,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml
|
||||||
ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
|
ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
|
||||||
const std::string & param_name,
|
const std::string & param_name,
|
||||||
int chunk_index) {
|
int chunk_index) {
|
||||||
|
// NPU prompt processing stage
|
||||||
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
|
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
|
||||||
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
|
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue