diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 7c72c1fb34..f429b796b5 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -180,19 +180,13 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { } m_inputs[src_name] = src; assert(stateful_kv_shape.rank().is_static()); - if (stateful_kv_shape.rank().get_length() != 0) { - auto param_node = - std::make_shared(get_ov_type(src), stateful_kv_shape); - param_node->set_friendly_name(src_name); - param_node->output(0).get_tensor().set_names({src_name}); - m_model_inputs[src_name] = param_node; - } else { - auto param_node = - std::make_shared(get_ov_type(src), get_graph_input_shape(node, src)); - param_node->set_friendly_name(src_name); - param_node->output(0).get_tensor().set_names({src_name}); - m_model_inputs[src_name] = param_node; - } + ov::PartialShape param_shape = (stateful_kv_shape.rank().get_length() != 0) + ? stateful_kv_shape + : get_graph_input_shape(node, src); + auto param_node = std::make_shared(get_ov_type(src), param_shape); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; } } }