Modified API GgmlOvDecoder::get_output_type(const std::string & name)
This commit is contained in:
parent
f516db1db5
commit
6d7a0d6047
|
|
@ -779,8 +779,8 @@ ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const {
|
|||
return ov::PartialShape(get_shape(ggml_tensor));
|
||||
}
|
||||
|
||||
ov::element::Type GgmlOvDecoder::get_output_type(const std::string & name) const {
|
||||
return get_ov_type(m_outputs.at(name));
|
||||
ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const {
|
||||
return get_ov_type(m_node_info_list[node_idx].node);
|
||||
}
|
||||
|
||||
std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ public:
|
|||
|
||||
virtual ov::PartialShape get_output_shape(int node_idx) const override;
|
||||
|
||||
virtual ov::element::Type get_output_type(const std::string & name) const override;
|
||||
virtual ov::element::Type get_output_type(const int node_idx) const override;
|
||||
|
||||
virtual int32_t * get_input_op_params(const std::string & name) const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ public:
|
|||
|
||||
virtual PartialShape get_output_shape(int node_idx) const = 0;
|
||||
|
||||
virtual element::Type get_output_type(const std::string& name) const = 0;
|
||||
virtual element::Type get_output_type(const int node_idx) const = 0;
|
||||
|
||||
virtual int32_t* get_input_op_params(const std::string& name) const = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -61,8 +61,8 @@ public:
|
|||
|
||||
int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); }
|
||||
|
||||
ov::element::Type get_output_type(size_t index) const {
|
||||
return m_decoder->get_output_type(m_output_names[index]);
|
||||
ov::element::Type get_output_type() const {
|
||||
return m_decoder->get_output_type(m_node_idx);
|
||||
}
|
||||
|
||||
Output<Node> get_input(int idx) const override {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace ggml {
|
|||
namespace op {
|
||||
|
||||
OutputVector translate_cpy(const NodeContext & context) {
|
||||
auto res = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_output_type(0));
|
||||
auto res = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_output_type());
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ OutputVector translate_get_rows(const NodeContext & context) {
|
|||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
|
||||
}
|
||||
|
||||
if (res.get_element_type() != context.get_output_type(0)) {
|
||||
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type(0));
|
||||
if (res.get_element_type() != context.get_output_type()) {
|
||||
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type());
|
||||
}
|
||||
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ OutputVector translate_set_rows(const NodeContext & context) {
|
|||
auto indices = context.get_input(1);
|
||||
auto dst = context.get_input(2);
|
||||
|
||||
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
|
||||
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type());
|
||||
|
||||
auto dst_shape = context.get_output_shape().to_shape();
|
||||
|
||||
|
|
|
|||
|
|
@ -63,8 +63,8 @@ OutputVector translate_soft_max(const NodeContext & context) {
|
|||
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
|
||||
}
|
||||
|
||||
if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
|
||||
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
|
||||
if (mask_node_sliced.get_element_type() != context.get_output_type()) {
|
||||
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type());
|
||||
}
|
||||
|
||||
Output<Node> slope_mask;
|
||||
|
|
|
|||
Loading…
Reference in New Issue