Modified API GgmlOvDecoder::get_output_type(const std::string & name)

This commit is contained in:
Xuejun Zhai 2025-12-03 23:13:18 -08:00 committed by Mustafa Cavus
parent f516db1db5
commit 6d7a0d6047
8 changed files with 12 additions and 12 deletions

View File

@ -779,8 +779,8 @@ ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const {
return ov::PartialShape(get_shape(ggml_tensor));
}
ov::element::Type GgmlOvDecoder::get_output_type(const std::string & name) const {
return get_ov_type(m_outputs.at(name));
ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const {
return get_ov_type(m_node_info_list[node_idx].node);
}
std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {

View File

@ -107,7 +107,7 @@ public:
virtual ov::PartialShape get_output_shape(int node_idx) const override;
virtual ov::element::Type get_output_type(const std::string & name) const override;
virtual ov::element::Type get_output_type(const int node_idx) const override;
virtual int32_t * get_input_op_params(const std::string & name) const override;

View File

@ -39,7 +39,7 @@ public:
virtual PartialShape get_output_shape(int node_idx) const = 0;
virtual element::Type get_output_type(const std::string& name) const = 0;
virtual element::Type get_output_type(const int node_idx) const = 0;
virtual int32_t* get_input_op_params(const std::string& name) const = 0;

View File

@ -61,8 +61,8 @@ public:
int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); }
ov::element::Type get_output_type(size_t index) const {
return m_decoder->get_output_type(m_output_names[index]);
ov::element::Type get_output_type() const {
return m_decoder->get_output_type(m_node_idx);
}
Output<Node> get_input(int idx) const override {

View File

@ -11,7 +11,7 @@ namespace ggml {
namespace op {
OutputVector translate_cpy(const NodeContext & context) {
auto res = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_output_type(0));
auto res = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_output_type());
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -42,8 +42,8 @@ OutputVector translate_get_rows(const NodeContext & context) {
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
}
if (res.get_element_type() != context.get_output_type(0)) {
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type(0));
if (res.get_element_type() != context.get_output_type()) {
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type());
}
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
return rename_outputs_with_suffix({res}, context.get_name());

View File

@ -32,7 +32,7 @@ OutputVector translate_set_rows(const NodeContext & context) {
auto indices = context.get_input(1);
auto dst = context.get_input(2);
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type());
auto dst_shape = context.get_output_shape().to_shape();

View File

@ -63,8 +63,8 @@ OutputVector translate_soft_max(const NodeContext & context) {
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
}
if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
if (mask_node_sliced.get_element_type() != context.get_output_type()) {
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type());
}
Output<Node> slope_mask;