Fix accuracy issue in fallback naive run

This commit is contained in:
Yu, Zijun 2025-11-27 15:52:20 +08:00 committed by Mustafa Cavus
parent 59e7e7c47d
commit 65348b5d20
18 changed files with 134 additions and 98 deletions

View File

@ -89,32 +89,60 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::sh
m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node);
m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node);
}
// Iterate through node_info_list to create model inputs and outputs.
// For inputs: if an input of a node is not seen as an output of any previous node, it is a model input.
// For outputs: every node output is a model output unless its data_addr is overridden by a later node.
std::map<void *, ggml_tensor *> data_addr_map;
std::unordered_set<std::string> output_name_set;
for (const auto & node_info : m_node_info_list) {
for (const auto & it : node_info.node_inputs) {
const auto & src_name = it.first;
const auto & src_node = it.second;
if (output_name_set.find(src_name) == output_name_set.end() &&
m_model_weights.find(src_name) == m_model_weights.end() &&
m_model_inputs.find(src_name) == m_model_inputs.end()) {
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src_node), ov::Shape(get_shape(src_node)));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
}
}
output_name_set.emplace(node_info.node_output_name);
data_addr_map[node_info.data_addr] = node_info.node_output;
}
for (const auto & it : data_addr_map) {
// No need to add view tensors as model outputs
if (it.second->op != GGML_OP_VIEW) {
m_model_outputs[std::string(it.second->name)] = it.second;
}
}
}
// Called in GgmlOvDecoder constructor. Three cases: 1. constructing a decoder for the whole graph;
// 2. constructing a decoder for a node;
// 3. constructing a decoder for the whole graph naively (op test case)
void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
std::string node_name;
NodeInfo current_node_info;
auto node_name = std::string(node->name);
auto node_output_name = node_name;
auto * node_output = node;
if (node->op == GGML_OP_SET_ROWS) {
// SET_ROWS updates the tensor in place. For later ov ops that use the
// view_src of SET_ROWS, we need to make sure they get the updated tensor
// by putting the view_src name in the tensor_map in
// <openvino>/src/frontends/ggml/src/translate_session.cpp
node_name = std::string(node->view_src->name);
} else {
node_name = std::string(node->name);
node_output_name = std::string(node->view_src->name);
node_output = node->view_src;
}
m_output_names.push_back(node_name);
m_outputs[node_name] = node;
m_output_names.push_back(node_output_name);
m_outputs[node_output_name] = node_output;
current_node_info.node = node;
current_node_info.node_name = node_name;
current_node_info.node_outputs[node_name] = node;
current_node_info.node_outputs_names.push_back(node_name);
current_node_info.node_output = node_output;
current_node_info.node_output_name = node_output_name;
current_node_info.node_op_case = 0;
current_node_info.data_addr = node->data;
for (int i = 0; i < GGML_MAX_SRC; i++) {
auto * src = node->src[i];
@ -127,17 +155,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
current_node_info.node_inputs[src_name] = src;
current_node_info.node_inputs_names.push_back(src_name);
// Add model inputs and weights constants, if called for the whole graph
if (naive) {
if (m_model_weights.find(src_name) == m_model_weights.end()) {
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
}
} else if (!src->view_src) {
// Add model inputs
if (!naive && !src->view_src) {
ggml_backend_buffer * buffer = src->buffer;
if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || src->flags & GGML_TENSOR_FLAG_INPUT) {
@ -157,18 +176,15 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
}
}
// Add model outputs, if called for the whole graph
if (naive) {
m_model_output_names.push_back(node_name);
} else {
// Add model outputs
if (!naive) {
// Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches
static std::set<std::string> debug_output_names = {};
// Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph
if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT ||
node_name.find("output") != std::string::npos || debug_output_names.count(node_name)) {
if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), node_name);
it == m_model_output_names.end()) {
m_model_output_names.push_back(node_name);
node_output_name.find("output") != std::string::npos || debug_output_names.count(node_output_name)) {
if (m_model_outputs.find(node_output_name) == m_model_outputs.end()) {
m_model_outputs[node_output_name] = node_output;
}
}
}
@ -176,7 +192,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
m_node_info_list.push_back(current_node_info);
}
int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) {
int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const {
int op_case = 0;
switch (node->op) {
case GGML_OP_RESHAPE: {
@ -370,9 +386,6 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
input_shape = ov::PartialShape{1, 1, 1, len};
} else if (input->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ov::PartialShape{get_shape(input->view_src)};
} else {
input_shape = ov::PartialShape{get_shape(input)};
}
@ -762,17 +775,11 @@ std::vector<size_t> GgmlOvDecoder::get_output_stride(const std::string & name) c
// Return the shape of the output tensor registered under `name`.
// The SET_ROWS redirect that used to live here was removed by this commit:
// outputs are resolved to their view_src when they are registered (see
// set_input_output, which stores node->view_src as the output tensor), so the
// tensor looked up in m_outputs already carries the correct shape.
ov::PartialShape GgmlOvDecoder::get_output_shape(const std::string & name) const {
    auto * ggml_tensor = m_outputs.at(name);
    return ov::PartialShape(get_shape(ggml_tensor));
}
// Return the shape of the canonical output of node `node_idx`.
// node_output already points at the underlying tensor (it is set to
// node->view_src for in-place ops such as SET_ROWS when the node info is
// built), so no extra redirection is needed here.
ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const {
    auto * ggml_tensor = m_node_info_list[node_idx].node_output;
    return ov::PartialShape(get_shape(ggml_tensor));
}
@ -785,7 +792,7 @@ std::vector<std::string> GgmlOvDecoder::get_output_names() const {
}
// Return the output names of node `node_idx`.
// Each node now has exactly one canonical output name (NodeInfo::node_output_name),
// so the result is always a single-element vector. The rendered diff showed two
// consecutive return statements (old list-based, new single-name); only the new
// one is kept — the second return was unreachable.
std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {
    return {m_node_info_list[node_idx].node_output_name};
}
const std::string & GgmlOvDecoder::get_op_name() const {
@ -809,8 +816,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
return m_outputs.at(name)->op_params;
}
int32_t * GgmlOvDecoder::get_output_op_params(int node_idx, const std::string & name) const {
return m_node_info_list[node_idx].node_outputs.at(name)->op_params;
int32_t * GgmlOvDecoder::get_output_op_params(int node_idx) const {
return m_node_info_list[node_idx].node->op_params;
}
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const {

View File

@ -51,13 +51,14 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
public:
// Per-node bookkeeping gathered while walking the cgraph. The rendered diff
// interleaved the old and new member lists, duplicating node_name and
// node_op_type (a redefinition error); this is the post-change member set.
struct NodeInfo {
    ggml_tensor * node = nullptr;                      // the cgraph node itself
    std::string node_name;                             // name the node is registered under
    std::string node_op_type;                          // result of compute_op_type(node)
    std::map<std::string, ggml_tensor *> node_inputs;  // input name -> source tensor
    std::vector<std::string> node_inputs_names;        // input names in src[] order
    ggml_tensor * node_output = nullptr;               // canonical output tensor (view_src for in-place ops)
    std::string node_output_name;                      // canonical output name
    int node_op_case = 0;                              // result of compute_op_case(node)
    // Output data address; a later node writing the same address overrides
    // this node as a model output (see the data_addr_map pass in the ctor).
    void * data_addr = nullptr;
};
// Graph decoder
GgmlOvDecoder(ggml_cgraph * cgraph,
@ -106,7 +107,7 @@ public:
virtual ov::PartialShape get_output_shape(const std::string & name) const override;
virtual ov::PartialShape get_output_shape(int node_idx, const std::string & name) const override;
virtual ov::PartialShape get_output_shape(int node_idx) const override;
virtual std::vector<size_t> get_output_stride(const std::string & name) const override;
@ -118,7 +119,7 @@ public:
virtual int32_t * get_output_op_params(const std::string & name) const override;
virtual int32_t * get_output_op_params(int node_idx, const std::string & name) const override;
virtual int32_t * get_output_op_params(int node_idx) const override;
virtual std::vector<std::string> get_output_names() const override;
@ -156,7 +157,16 @@ public:
return m_model_weights;
}
// Return the names of the model outputs, assembled on demand from the keys of
// m_model_outputs. Returns by value (the interface changed from returning a
// const reference to a stored name list, which no longer exists).
virtual std::vector<std::string> get_model_output_names() const override {
    std::vector<std::string> output_names;
    output_names.reserve(m_model_outputs.size());
    for (const auto & entry : m_model_outputs) {
        output_names.push_back(entry.first);  // key is the output name
    }
    return output_names;
}
const std::map<std::string, ggml_tensor *> & get_model_outputs() const { return m_model_outputs; }
virtual int get_ctx_size() const { return m_model_params.ctx; }
@ -214,14 +224,15 @@ public:
bool m_is_prefill = false;
int m_prefill_chunk_size = 0;
private:
void set_input_output(ggml_tensor * node, bool naive = false);
void add_extra_inputs();
static std::vector<size_t> get_shape(const ggml_tensor * tensor);
static std::vector<size_t> get_stride(const ggml_tensor * tensor);
static ov::element::Type get_ov_type(const ggml_tensor * tensor);
int compute_op_case(const ggml_tensor * node);
std::string compute_op_type(const ggml_tensor * node);
static std::string compute_op_type(const ggml_tensor * node);
private:
void set_input_output(ggml_tensor * node, bool naive = false);
void add_extra_inputs();
int compute_op_case(const ggml_tensor * node) const;
void validate_cgraph() const;
@ -236,7 +247,7 @@ private:
std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
std::map<std::string, ggml_tensor *> m_model_outputs;
std::vector<NodeInfo> m_node_info_list;
ModelParams m_model_params;

View File

@ -39,7 +39,7 @@ public:
virtual PartialShape get_output_shape(const std::string& name) const = 0;
virtual PartialShape get_output_shape(int node_idx, const std::string& name) const = 0;
virtual PartialShape get_output_shape(int node_idx) const = 0;
virtual std::vector<size_t> get_output_stride(const std::string& name) const = 0;
@ -51,7 +51,7 @@ public:
virtual int32_t* get_output_op_params(const std::string& name) const = 0;
virtual int32_t* get_output_op_params(int node_idx, const std::string& name) const = 0;
virtual int32_t * get_output_op_params(int node_idx) const = 0;
virtual std::vector<std::string> get_output_names() const = 0;
@ -72,7 +72,7 @@ public:
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_inputs() const = 0;
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_extra_inputs() const = 0;
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_weights() const = 0;
virtual const std::vector<std::string>& get_model_output_names() const = 0;
virtual std::vector<std::string> get_model_output_names() const = 0;
virtual int32_t* get_rope_params() const = 0;
// virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;

View File

@ -53,17 +53,13 @@ public:
std::string get_output_name() const { return m_output_names[0]; }
PartialShape get_output_shape(size_t index) const {
return m_decoder->get_output_shape(m_node_idx, m_output_names[index]);
}
PartialShape get_output_shape() const { return m_decoder->get_output_shape(m_node_idx); }
int32_t* get_input_op_params(size_t index) const {
return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]);
}
int32_t* get_output_op_params(size_t index) const {
return m_decoder->get_output_op_params(m_node_idx, m_output_names[index]);
}
int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); }
ov::element::Type get_output_type(size_t index) const {
return m_decoder->get_output_type(m_output_names[index]);

View File

@ -22,7 +22,7 @@ OutputVector translate_cont(const NodeContext & context) {
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case");
auto src_shape = context.get_input_shape(0).to_shape();
auto dst_shape = context.get_output_shape(0).to_shape();
auto dst_shape = context.get_output_shape().to_shape();
ov::Output<Node> res;
if (op_case == 1) {

View File

@ -26,7 +26,7 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
auto v = context.get_input(2);
auto mask = context.get_input(3);
float * params = reinterpret_cast<float *>(context.get_output_op_params(0));
float * params = reinterpret_cast<float *>(context.get_output_op_params());
float scale = params[0];
// float max_bias = params[1];
// float logit_softcap = params[2];

View File

@ -32,7 +32,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
src1 = split->output(1);
}
int32_t * params = context.get_output_op_params(0);
int32_t * params = context.get_output_op_params();
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);

View File

@ -32,7 +32,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
src1 = split->output(1);
}
int32_t * params = context.get_output_op_params(0);
int32_t * params = context.get_output_op_params();
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);

View File

@ -32,10 +32,12 @@ OutputVector translate_permute(const NodeContext & context) {
if (op_case == 1) {
res = std::make_shared<ov::op::v1::Transpose>(src, perm);
} else if (op_case == 4) {
auto output_shape = context.get_output_shape(0).to_shape();
auto output_shape = context.get_output_shape().to_shape();
auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]});
auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
auto n_seq_active = context.get_input("n_seq_active");
auto n_seq_active = context.has_input("n_seq_active") ?
context.get_input("n_seq_active") :
ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[0]});
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto new_shape =
@ -49,26 +51,39 @@ OutputVector translate_permute(const NodeContext & context) {
res = std::make_shared<ov::op::v1::Transpose>(reshaped, perm);
} else {
auto cache_shape = src.get_partial_shape();
auto output_shape = context.get_output_shape(0).to_shape();
auto output_shape = context.get_output_shape().to_shape();
int64_t head_size = output_shape[3];
int64_t n_heads = output_shape[1];
int64_t ctx_per_seq = cache_shape[2].is_static() ? cache_shape[2].get_length() : -1;
int64_t n_seq = cache_shape[1].get_length();
Output<Node> attention_size;
if (op_case == 2) {
if (!context.has_input("attention_size")) {
attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]});
} else if (op_case == 2) {
attention_size = context.get_input("attention_size");
} else {
attention_size = context.get_input("attention_size_swa");
}
Output<Node> seq_active_start;
Output<Node> seq_active_end;
if (context.has_input("seq_active_start")) {
seq_active_start = context.get_input("seq_active_start");
seq_active_end = context.get_input("seq_active_end");
} else {
int64_t n_seq_active = output_shape[0];
size_t offset = *((size_t *) context.get_input_op_params(0));
int64_t seq_active_start_val = offset / context.get_input_stride(0)[0];
int64_t seq_active_end_val = seq_active_start_val + n_seq_active;
seq_active_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_start_val});
seq_active_end = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_end_val});
}
// 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size]
// 2. slice out the active sequences
// 3. slice out the attention part in each sequence
// 4. permute
auto seq_active_start = context.get_input("seq_active_start");
auto seq_active_end = context.get_input("seq_active_end");
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});

View File

@ -20,7 +20,7 @@ namespace op {
OutputVector translate_reshape(const NodeContext & context) {
num_inputs_check(context, 1, 1);
if (context.get_input_shape(0) == context.get_output_shape(0)) {
if (context.get_input_shape(0) == context.get_output_shape()) {
return {context.get_input(0)};
}
@ -29,7 +29,7 @@ OutputVector translate_reshape(const NodeContext & context) {
op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6,
"Unsupported RESHAPE case");
auto output_shape = context.get_output_shape(0).to_shape();
auto output_shape = context.get_output_shape().to_shape();
std::shared_ptr<ov::Node> new_shape_node;
if (op_case == 1) {
new_shape_node = ov::op::v0::Constant::create(
@ -50,18 +50,18 @@ OutputVector translate_reshape(const NodeContext & context) {
return {context.get_input(0).get_node_shared_ptr()->input_value(0)};
} else if (op_case == 5) {
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape(0).to_shape()[3]};
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);
// // Alternative
// auto token_len = context.get_input("token_len");
// auto emb_size =
// ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape(0).to_shape()[3]});
// ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape().to_shape()[3]});
// auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
// new_shape_node = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, one, token_len, emb_size}, 0);
} else if (op_case == 6) {
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape(0).to_shape());
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape().to_shape());
}
auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false);
return rename_outputs_with_suffix({res}, context.get_name());

View File

@ -27,7 +27,7 @@ OutputVector translate_rms_norm(const NodeContext & context) {
square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true);
float eps;
memcpy(&eps, context.get_output_op_params(0), sizeof(float));
memcpy(&eps, context.get_output_op_params(), sizeof(float));
auto rms = std::make_shared<ov::op::v0::Sqrt>(
std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps})));

View File

@ -31,8 +31,8 @@ OutputVector translate_rope(const NodeContext & context) {
ov::Output<Node> res;
auto data_node = context.get_input(0).get_node_shared_ptr();
auto output_shape = context.get_output_shape(0).to_shape();
int32_t * op_params = context.get_output_op_params(0);
auto output_shape = context.get_output_shape().to_shape();
int32_t * op_params = context.get_output_op_params();
Output<Node> cos_theta_node;
Output<Node> sin_theta_node;

View File

@ -15,7 +15,7 @@ OutputVector translate_scale(const NodeContext & context) {
num_inputs_check(context, 1, 1);
float scale;
memcpy(&scale, context.get_output_op_params(0), sizeof(float));
memcpy(&scale, context.get_output_op_params(), sizeof(float));
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
auto res = std::make_shared<ov::op::v1::Multiply>(context.get_input(0), scale_node);

View File

@ -34,7 +34,7 @@ OutputVector translate_set_rows(const NodeContext & context) {
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
auto dst_shape = context.get_output_shape(0).to_shape();
auto dst_shape = context.get_output_shape().to_shape();
auto ind_squeezed =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 1, 2}));

View File

@ -31,7 +31,7 @@ OutputVector translate_soft_max(const NodeContext & context) {
float scale = 1.0f;
float max_bias = 0.0f;
auto * op_params = context.get_output_op_params(0);
auto * op_params = context.get_output_op_params();
memcpy(&scale, (float *) op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) op_params + 1, sizeof(float));
auto src0_shape = context.get_input_shape(0).get_shape();

View File

@ -10,7 +10,7 @@ OutputVector translate_view(const NodeContext & context) {
num_inputs_check(context, 1, 1);
if (context.get_op_case() == 2) {
auto dst_shape = context.get_output_shape(0).to_shape();
auto dst_shape = context.get_output_shape().to_shape();
return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[2] * dst_shape[3])},
context.get_name());
}

View File

@ -203,7 +203,16 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
results.push_back(result);
}
resulting_model = std::make_shared<Model>(results, params);
ov::ParameterVector used_params;
for (const auto & param : params) {
if (!param->output(0).get_target_inputs().empty()) {
used_params.push_back(param);
}
}
// if (auto diff = params.size() - used_params.size()) {
// GGML_LOG_INFO("%zu parameters are not used in the model.", diff);
// }
resulting_model = std::make_shared<Model>(results, used_params);
apply_transformations(resulting_model);
return resulting_model;

View File

@ -362,7 +362,7 @@ std::map<ggml_type, ExtraQuantType> get_types_to_requant(const std::string & dev
}
// Decide whether a cgraph should be handled "naively" (whole graph, op-test
// style) based on its node count. The rendered diff contained both the old
// (20) and new (100) threshold definitions — a redefinition error; the added
// value 100 is kept.
// NOTE(review): threshold raised 20 -> 100 in this commit — presumably to
// route more small graphs through the naive path; confirm rationale.
bool is_naive(ggml_cgraph * cgraph) {
    constexpr int naive_graph_size_threshold = 100;
    return cgraph->n_nodes < naive_graph_size_threshold;
}
@ -412,7 +412,7 @@ ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
ov::Shape input_shape;
if (ggml_tensor->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor, ggml_tensor->view_src).to_shape();
input_shape = ggml_decoder->get_shape(ggml_tensor->view_src);
} else {
input_shape = ggml_decoder->get_input_shape(name).to_shape();
}
@ -545,15 +545,13 @@ ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggm
}
// Build an ov::Tensor that wraps — without copying — the ggml output buffer
// for `result_name`. The rendered diff interleaved the old lookup
// (get_output_ggml_tensor / get_output_type / get_output_shape) with the new
// one (get_model_outputs / get_ov_type / get_shape), duplicating the local
// declarations; this is the post-change version.
ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & result_name) {
    auto * ggml_tensor = ggml_decoder->get_model_outputs().at(result_name);
    auto output_type = ggml_decoder->get_ov_type(ggml_tensor);
    auto output_shape = ggml_decoder->get_shape(ggml_tensor);
    // Workaround for static-shape models: the final logits tensor holds a
    // single token. NOTE(review): assumes index 1 is the token dimension —
    // confirm against the model layout.
    if (ggml_decoder->is_static() && result_name == "result_output") {
        output_shape[1] = 1;
    }
    // Zero-copy: the tensor aliases ggml's buffer via the raw data pointer,
    // so its lifetime must not exceed that of the ggml tensor.
    ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
    return output_tensor;
}