diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 51fb433410..7c72c1fb34 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -44,9 +44,11 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, ComputeParams & compute_params, std::map> & model_weights, bool is_static, + bool is_stateful, bool is_prefill, int prefill_chunk_size) : m_is_static(is_static), + m_is_stateful(is_stateful), m_is_prefill(is_prefill), m_prefill_chunk_size(prefill_chunk_size), m_cgraph(cgraph), @@ -157,19 +159,40 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { ggml_backend_buffer * buffer = src->buffer; if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || src->flags & GGML_TENSOR_FLAG_INPUT) { + ov::PartialShape stateful_kv_shape; // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { assert(src_name.find("cache_k") == 0 || src_name.find("cache_v") == 0); + if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); it == m_model_params.kv_names.end()) { + m_model_params.kv_names.push_back(src_name); + if (is_stateful()) { + // TODO: The shape modification for stateful model below is not validated for all supported models yet. More generic solution might be needed + // to enable additional cases. Ideally, this could be removed from decoder and done as part of a transformation later. + auto stateless_kv_shape = get_graph_input_shape(node, src); + assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && stateless_kv_shape[1] == 1 + && stateless_kv_shape[2].is_dynamic() && stateless_kv_shape[3] == (m_model_params.n_heads_kv*m_model_params.head_size)); + stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, m_model_params.head_size}; + } + } } if (m_model_inputs.find(src_name) != m_model_inputs.end()) { continue; } m_inputs[src_name] = src; - auto param_node = - std::make_shared(get_ov_type(src), get_graph_input_shape(node, src)); - param_node->set_friendly_name(src_name); - param_node->output(0).get_tensor().set_names({src_name}); - m_model_inputs[src_name] = param_node; + assert(stateful_kv_shape.rank().is_static()); + if (stateful_kv_shape.rank().get_length() != 0) { + auto param_node = + std::make_shared(get_ov_type(src), stateful_kv_shape); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } else { + auto param_node = + std::make_shared(get_ov_type(src), get_graph_input_shape(node, src)); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } } } } @@ -378,6 +401,8 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co } else if (name.find("KQ_mask") == 0) { if (m_is_static) { input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx}; + } else if (m_is_stateful) { + input_shape = ov::PartialShape{1, 1, -1, -1}; } else { input_shape = ov::PartialShape{-1, 1, -1, -1}; } @@ -465,15 +490,15 @@ const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name return nullptr; } -// std::map GgmlOvDecoder::get_kv_param_res_names() const { -// std::map kv_param_res_names; -// for (const auto & name : m_model_params.kv_names) { -// if (name.find("cache_k") == 0 || name.find("cache_v") == 0) { -// kv_param_res_names[name] = name; -// } -// } -// return kv_param_res_names; -// } +std::map GgmlOvDecoder::get_kv_param_res_names() const { + std::map kv_param_res_names; + for (const auto & name : m_model_params.kv_names) { + if (name.find("cache_k") == 0 || name.find("cache_v") == 0) { + kv_param_res_names[name] = name; + } + } + return kv_param_res_names; +} std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph) { std::map> model_weights; diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index 0b302b9320..4afec272e1 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -23,7 +23,7 @@ struct ModelParams { int32_t * rope_params = nullptr; std::vector swa_layers; - // std::vector kv_names; + std::vector kv_names; bool operator==(const ModelParams & other) const { return n_seq == other.n_seq && n_heads == other.n_heads && n_heads_kv == other.n_heads_kv && @@ -66,6 +66,7 @@ public: ComputeParams & compute_params, std::map> & model_weights, bool is_static, + bool is_stateful = false, bool is_prefill = false, int prefill_chunk_size = 256); @@ -171,10 +172,12 @@ public: virtual int32_t * get_rope_params() const override { return m_model_params.rope_params; } - // virtual std::map get_kv_param_res_names() const override; + virtual std::map get_kv_param_res_names() const override; virtual bool is_static() const override { return m_is_static; } + virtual bool is_stateful() const override { return m_is_stateful; } + ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const; static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename); @@ -200,6 +203,7 @@ public: void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; } bool m_is_static = false; + bool m_is_stateful = false; bool m_is_prefill = false; int m_prefill_chunk_size = 0; diff --git a/ggml/src/ggml-openvino/openvino/decoder.hpp b/ggml/src/ggml-openvino/openvino/decoder.hpp index 1603c7fd20..3b8da2be5d 100644 --- a/ggml/src/ggml-openvino/openvino/decoder.hpp +++ b/ggml/src/ggml-openvino/openvino/decoder.hpp @@ -59,10 +59,13 @@ public: virtual std::vector get_model_output_names() const = 0; virtual int32_t* get_rope_params() const = 0; - // virtual std::map get_kv_param_res_names() const = 0; + + virtual std::map get_kv_param_res_names() const = 0; virtual bool is_static() const = 0; + virtual bool is_stateful() const = 0; + virtual int is_swa_layer(int layer) const = 0; }; diff --git a/ggml/src/ggml-openvino/openvino/node_context.hpp b/ggml/src/ggml-openvino/openvino/node_context.hpp index a0666b21ac..235adcc784 100644 --- a/ggml/src/ggml-openvino/openvino/node_context.hpp +++ b/ggml/src/ggml-openvino/openvino/node_context.hpp @@ -91,8 +91,11 @@ public: int get_op_case() const { return m_decoder->get_op_case(m_node_idx); } + bool is_static() const { return m_decoder->is_static(); } + bool is_stateful() const { return m_decoder->is_stateful(); } + private: std::shared_ptr m_decoder; std::shared_ptr& m_tensor_map; diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp index dc8454a199..d6e7a35534 100644 --- a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -37,6 +37,9 @@ OutputVector translate_get_rows(const NodeContext & context) { auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); data = std::make_shared(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); res = std::make_shared(data, indices, axis, 1); + } else if (context.is_stateful() && data.get_partial_shape().rank() == 3) { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + res = std::make_shared(data, indices, axis, 1); } else { auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); res = std::make_shared(data, indices, axis); @@ -45,7 +48,9 @@ OutputVector translate_get_rows(const NodeContext & context) { if (res.get_element_type() != context.get_output_type()) { res = std::make_shared(res, context.get_output_type()); } - res = std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + if (!(context.is_stateful())) { + res = std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + } return rename_outputs_with_suffix({res}, context.get_name()); } diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp index bfe09a2b84..fa7ab0c43f 100644 --- a/ggml/src/ggml-openvino/openvino/op/permute.cpp +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -29,7 +29,7 @@ OutputVector translate_permute(const NodeContext & context) { auto src = context.get_input(0); auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}); - if (op_case == 1) { + if (op_case == 1 || context.is_stateful()) { res = std::make_shared(src, perm); } else if (op_case == 4) { auto output_shape = context.get_output_shape().to_shape(); diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp index e26a8c778c..7eebd7b7b1 100644 --- a/ggml/src/ggml-openvino/openvino/op/reshape.cpp +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -32,10 +32,15 @@ OutputVector translate_reshape(const NodeContext & context) { auto output_shape = context.get_output_shape().to_shape(); std::shared_ptr new_shape_node; if (op_case == 1) { - new_shape_node = ov::op::v0::Constant::create( - ov::element::i64, {4}, - std::vector{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); - + if (context.is_stateful()) { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {3}, + std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + } else { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, + std::vector{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + } } else if (op_case == 2) { new_shape_node = ov::op::v0::Constant::create( ov::element::i64, {4}, @@ -50,8 +55,13 @@ OutputVector translate_reshape(const NodeContext & context) { return {context.get_input(0).get_node_shared_ptr()->input_value(0)}; } else if (op_case == 5) { - std::vector shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; - new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec); + if (context.is_stateful()) { + std::vector shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec); + } else { + std::vector shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec); + } // // Alternative // auto token_len = context.get_input("token_len"); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 96fbb6b795..b72e445706 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -54,9 +54,18 @@ OutputVector translate_rope(const NodeContext & context) { // The input comes from a VIEW int slice_len = output_shape[2] * output_shape[3]; data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr(); - auto data_shape = ov::op::v0::Constant::create( - ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); - data_node = std::make_shared(data_node, data_shape, false); + if (context.is_stateful()) { + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {3}, std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + data_node = std::make_shared(data_node, data_shape, false); + } else { + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + data_node = std::make_shared(data_node, data_shape, false); + } + //auto data_shape = ov::op::v0::Constant::create( + // ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + //data_node = std::make_shared(data_node, data_shape, false); } const int mode = op_params[2]; @@ -67,10 +76,19 @@ OutputVector translate_rope(const NodeContext & context) { auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); - auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3}); auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]}); - auto even_slice = std::make_shared(data_node, zero, end, two, three); - auto odd_slice = std::make_shared(data_node, one, end, two, three); + Output even_slice; + Output odd_slice; + int32_t unsqueeze_dim = 4; + if (context.is_stateful()) { + unsqueeze_dim = 3; + even_slice = std::make_shared(data_node, zero, end, two, two); + odd_slice = std::make_shared(data_node, one, end, two, two); + } else { + auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3}); + even_slice = std::make_shared(data_node, zero, end, two, three); + odd_slice = std::make_shared(data_node, one, end, two, three); + } Output first_half = std::make_shared(std::make_shared(even_slice, cos_theta_node), @@ -80,10 +98,10 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(odd_slice, cos_theta_node)); first_half = std::make_shared(first_half, - ov::op::v0::Constant::create(ov::element::i64, {1}, {4})); + ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim})); second_half = std::make_shared(second_half, - ov::op::v0::Constant::create(ov::element::i64, {1}, {4})); - auto stack = std::make_shared(OutputVector{first_half, second_half}, 4); + ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim})); + auto stack = std::make_shared(OutputVector{first_half, second_half}, unsqueeze_dim); auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); @@ -102,7 +120,11 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_0, sin_theta_node), std::make_shared(slice_data_node_1, cos_theta_node)); - res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, 3); + int32_t concat_dim = 3; + if (context.is_stateful()) { + concat_dim = 2; + } + res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, concat_dim); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp index 4ceb55589e..69c4ca7089 100644 --- a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +++ b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp @@ -45,7 +45,17 @@ OutputVector translate_set_rows(const NodeContext & context) { false); auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}); - Output res = std::make_shared(dst, ind_squeezed, data_reshaped, axes); + Output res; + if (context.is_stateful()) { + int concat_axis = 1; + int64_t dim2 = dst.get_partial_shape()[2].get_length(); + int64_t dim3 = dst.get_partial_shape()[3].get_length(); + data = std::make_shared( + data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false); + res = std::make_shared(OutputVector{dst, data}, concat_axis); + } else { + res = std::make_shared(dst, ind_squeezed, data_reshaped, axes); + } if (auto dst_reshape = std::dynamic_pointer_cast(dst.get_node_shared_ptr())) { // Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb] diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index ccd0947a2b..02e08c24f4 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -82,6 +83,20 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { std::shared_ptr mask_sliced; if (is_static) { mask_sliced = mask; + } else if (ggml_model_decoder.is_stateful()) { + auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0}); + auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1}); + auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1}); + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + auto shape_of_inp_pos = std::make_shared(inp_pos); + auto gather_inp_pos = std::make_shared(shape_of_inp_pos, two_1d, zero_1d); + auto stop = std::make_shared(ov::OutputVector{token_len_per_seq, gather_inp_pos}, 0); + mask_sliced = + std::make_shared(mask, zero_2d, stop, one_2d, axes); + mask_sliced = std::make_shared(mask_sliced, ov::element::f16); + mask_sliced->set_friendly_name(sliced_name); } else { auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); @@ -226,11 +241,11 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptr(); - // if (!ggml_model_decoder->is_static()) { - // const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names(); - // const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names); - // manager.register_pass(kv_param_res_pairs); - // } + if (ggml_model_decoder->is_stateful()) { + const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names(); + const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names); + manager.register_pass(kv_param_res_pairs); + } if (ggml_model_decoder->is_static()) { manager.register_pass(); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index bdda30fa6d..b7553f99c8 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -113,11 +114,20 @@ void ggml_rope_yarn_corr_dims(int n_dims, std::pair, ov::Output> make_sin_cos(int32_t * rope_params, std::shared_ptr inp_pos, - std::shared_ptr rope_freqs_weight) { - inp_pos = std::make_shared(inp_pos, ov::element::f32); - auto pos_perm = - std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{0, 3, 1, 2}); - inp_pos = std::make_shared(inp_pos, pos_perm); + std::shared_ptr rope_freqs_weight, + bool stateful) { + if (stateful) { + inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); + inp_pos = std::make_shared(inp_pos, pos_perm); + } else { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{0, 3, 1, 2}); + inp_pos = std::make_shared(inp_pos, pos_perm); + } float freq_base; float freq_scale; @@ -145,8 +155,14 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params factor[i] = theta_scale * factor[i - 1]; } - Output freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + Output freq_factors; + if (stateful) { + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); + } else { + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } if (rope_freqs_weight) { freq_factors = std::make_shared(freq_factors, rope_freqs_weight); } @@ -161,7 +177,12 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params theta = theta_interp; } else { auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); - auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } auto one_minus_ramp = std::make_shared(one, ramp_mix); theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), diff --git a/ggml/src/ggml-openvino/openvino/utils.hpp b/ggml/src/ggml-openvino/openvino/utils.hpp index 6c6d2ae8d4..4ffe37ada6 100644 --- a/ggml/src/ggml-openvino/openvino/utils.hpp +++ b/ggml/src/ggml-openvino/openvino/utils.hpp @@ -66,7 +66,8 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std:: std::pair, ov::Output> make_sin_cos(int32_t* rope_params, std::shared_ptr inp_pos, - std::shared_ptr rope_freqs_weight = nullptr); + std::shared_ptr rope_freqs_weight = nullptr, + bool stateful = false); ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 89cf51f880..ff94c4acfe 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -46,10 +46,14 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) { // Use device from singleton (initialized during backend init) const auto & device = ggml_openvino_get_device_name(); const auto is_static = ggml_openvino_is_npu(); - return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device); + bool stateful = false; + if (getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !is_static) { + stateful = true; + } + return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful); } -enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device) { +enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device, bool stateful) { auto & core = ov_singleton_core(); const auto & config = ggml_openvino_get_compile_config(); static auto is_static = false; @@ -99,6 +103,12 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin ggml_decoder->add_extra_inputs(); infer_request = infer_request_cache[key]; + auto * inp_pos = get_inp_pos_tensor(cgraph); + int32_t * pos_data = (int32_t *) inp_pos->data; + if (pos_data[0] == 0) { + infer_request->reset_state(); + } + decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; @@ -108,7 +118,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); - ggml_decoder = std::make_shared(cgraph, m_params, c_params, model_weights, is_static); + ggml_decoder = std::make_shared(cgraph, m_params, c_params, model_weights, is_static, stateful); decoder_end_time = ggml_time_us(); auto input_model = std::make_shared(ggml_decoder); @@ -202,6 +212,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { static std::string device = "NPU"; static auto is_static = true; + static auto stateful = false; static auto prefill_chunk_size = get_prefill_chunk_size(); const auto & config = ggml_openvino_get_compile_config(); @@ -265,9 +276,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); auto ggml_decoder_prefill = std::make_shared(cgraph, m_params, c_params, model_weights, - is_static, true, prefill_chunk_size); + is_static, stateful, true, prefill_chunk_size); auto ggml_decoder_decode = std::make_shared(cgraph, m_params, c_params, model_weights, - is_static, false, prefill_chunk_size); + is_static, stateful, false, prefill_chunk_size); decoder_end_time = ggml_time_us(); auto input_model_prefill = std::make_shared(ggml_decoder_prefill); @@ -606,8 +617,17 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr ggml_decoder, con if (ggml_decoder->is_static() && result_name == "result_output" && output_shape[2] == 0) { output_shape[2] = 1; } - ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data); - return output_tensor; + if (ggml_decoder->is_stateful() && result_name == "result_output") { + std::vector output_shape_3d; + for (size_t i=1; idata); + return output_tensor; + } else { + ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data); + return output_tensor; + } } size_t checksum(const void * data, size_t size) { diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 44ca2db00f..47bf2d4ff1 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -28,7 +28,7 @@ struct graph_key_hash { enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph); -enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, const std::string & device); +enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, const std::string & device, bool stateful = false); enum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph); size_t checksum(const void * data, size_t size);