Initial stateful graph support

Mustafa Cavus 2026-01-07 16:05:02 -08:00
parent 0d6f253e48
commit 5f30eacdb4
14 changed files with 197 additions and 58 deletions

View File

@@ -44,9 +44,11 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
ComputeParams & compute_params,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static,
bool is_stateful,
bool is_prefill,
int prefill_chunk_size) :
m_is_static(is_static),
m_is_stateful(is_stateful),
m_is_prefill(is_prefill),
m_prefill_chunk_size(prefill_chunk_size),
m_cgraph(cgraph),
@@ -157,19 +159,40 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
ggml_backend_buffer * buffer = src->buffer;
if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || src->flags & GGML_TENSOR_FLAG_INPUT) {
ov::PartialShape stateful_kv_shape;
// Buffers with usage GGML_BACKEND_BUFFER_USAGE_ANY are KV caches
if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) {
assert(src_name.find("cache_k") == 0 || src_name.find("cache_v") == 0);
if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); it == m_model_params.kv_names.end()) {
m_model_params.kv_names.push_back(src_name);
if (is_stateful()) {
// TODO: The shape modification for the stateful model below has not been validated for all supported models yet. A more generic solution may be needed
// to enable additional cases. Ideally, this could be removed from the decoder and done as part of a transformation later.
auto stateless_kv_shape = get_graph_input_shape(node, src);
assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && stateless_kv_shape[1] == 1
&& stateless_kv_shape[2].is_dynamic() && stateless_kv_shape[3] == (m_model_params.n_heads_kv*m_model_params.head_size));
stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, m_model_params.head_size};
}
}
}
if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
continue;
}
m_inputs[src_name] = src;
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
assert(stateful_kv_shape.rank().is_static());
if (stateful_kv_shape.rank().get_length() != 0) {
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), stateful_kv_shape);
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
} else {
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
}
}
}
}
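
For reference, a minimal sketch of the shape rewrite asserted above, with hypothetical values n_heads_kv = 8 and head_size = 128 (neither is taken from this commit): the flattened stateless KV shape is split into per-head dimensions so the cache can later be bound to a growing state variable.

// Stateless cache view: [1, 1, ?, n_heads_kv * head_size]
ov::PartialShape stateless_kv{1, 1, ov::Dimension::dynamic(), 8 * 128};
// Stateful cache view: [1, ?, n_heads_kv, head_size]
ov::PartialShape stateful_kv{stateless_kv[0], ov::Dimension::dynamic(), 8, 128};
// stateless_kv = [1,1,?,1024], stateful_kv = [1,?,8,128]
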
@@ -378,6 +401,8 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
} else if (name.find("KQ_mask") == 0) {
if (m_is_static) {
input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
} else if (m_is_stateful) {
input_shape = ov::PartialShape{1, 1, -1, -1};
} else {
input_shape = ov::PartialShape{-1, 1, -1, -1};
}
@@ -465,15 +490,15 @@ const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name
return nullptr;
}
// std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
// std::map<std::string, std::string> kv_param_res_names;
// for (const auto & name : m_model_params.kv_names) {
// if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
// kv_param_res_names[name] = name;
// }
// }
// return kv_param_res_names;
// }
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
std::map<std::string, std::string> kv_param_res_names;
for (const auto & name : m_model_params.kv_names) {
if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
kv_param_res_names[name] = name;
}
}
return kv_param_res_names;
}
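
For context, a minimal sketch (layer names hypothetical, not from this commit; model is the converted ov::Model) of how this identity map is consumed: ov::pass::MakeStateful accepts parameter-name to result-name pairs and replaces each matched Parameter/Result pair with ReadValue/Assign ops, so the KV cache persists inside the model between inferences.

ov::pass::Manager manager;
manager.register_pass<ov::pass::MakeStateful>(std::map<std::string, std::string>{
    {"cache_k_l0", "cache_k_l0"},  // hypothetical layer-0 K cache
    {"cache_v_l0", "cache_v_l0"},  // hypothetical layer-0 V cache
});
manager.run_passes(model);
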
std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph) {
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;

View File

@@ -23,7 +23,7 @@ struct ModelParams {
int32_t * rope_params = nullptr;
std::vector<int> swa_layers;
// std::vector<std::string> kv_names;
std::vector<std::string> kv_names;
bool operator==(const ModelParams & other) const {
return n_seq == other.n_seq && n_heads == other.n_heads && n_heads_kv == other.n_heads_kv &&
@@ -66,6 +66,7 @@ public:
ComputeParams & compute_params,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static,
bool is_stateful = false,
bool is_prefill = false,
int prefill_chunk_size = 256);
@@ -171,10 +172,12 @@ public:
virtual int32_t * get_rope_params() const override { return m_model_params.rope_params; }
// virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
virtual bool is_static() const override { return m_is_static; }
virtual bool is_stateful() const override { return m_is_stateful; }
ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;
static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
@@ -200,6 +203,7 @@ public:
void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; }
bool m_is_static = false;
bool m_is_stateful = false;
bool m_is_prefill = false;
int m_prefill_chunk_size = 0;

View File

@@ -59,10 +59,13 @@ public:
virtual std::vector<std::string> get_model_output_names() const = 0;
virtual int32_t* get_rope_params() const = 0;
// virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
virtual bool is_static() const = 0;
virtual bool is_stateful() const = 0;
virtual int is_swa_layer(int layer) const = 0;
};

View File

@@ -91,8 +91,11 @@ public:
int get_op_case() const {
return m_decoder->get_op_case(m_node_idx);
}
bool is_static() const { return m_decoder->is_static(); }
bool is_stateful() const { return m_decoder->is_stateful(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder;
std::shared_ptr<TensorMap>& m_tensor_map;

View File

@@ -37,6 +37,9 @@ OutputVector translate_get_rows(const NodeContext & context) {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
data = std::make_shared<ov::op::v0::Squeeze>(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
} else if (context.is_stateful() && data.get_partial_shape().rank() == 3) {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
} else {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
@@ -45,7 +48,9 @@ OutputVector translate_get_rows(const NodeContext & context) {
if (res.get_element_type() != context.get_output_type()) {
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type());
}
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
if (!(context.is_stateful())) {
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
}
return rename_outputs_with_suffix({res}, context.get_name());
}
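
A shape walk-through of the stateful 3-D branch above, with hypothetical sizes (10 cached rows, embedding 4096, 3 gathered rows); batch_dims = 1 keeps the leading dimension of data and indices aligned.

// data:    [1, 10, 4096]  (batch, rows, emb)
// indices: [1, 3]         (batch, row ids)
// Gather(axis = 1, batch_dims = 1) -> [1, 3, 4096]
auto gathered = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
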

View File

@@ -29,7 +29,7 @@ OutputVector translate_permute(const NodeContext & context) {
auto src = context.get_input(0);
auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});
if (op_case == 1) {
if (op_case == 1 || context.is_stateful()) {
res = std::make_shared<ov::op::v1::Transpose>(src, perm);
} else if (op_case == 4) {
auto output_shape = context.get_output_shape().to_shape();

View File

@@ -32,10 +32,15 @@ OutputVector translate_reshape(const NodeContext & context) {
auto output_shape = context.get_output_shape().to_shape();
std::shared_ptr<ov::Node> new_shape_node;
if (op_case == 1) {
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {4},
std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
if (context.is_stateful()) {
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {3},
std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
} else {
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {4},
std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
}
} else if (op_case == 2) {
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {4},
@@ -50,8 +55,13 @@ OutputVector translate_reshape(const NodeContext & context) {
return {context.get_input(0).get_node_shared_ptr()->input_value(0)};
} else if (op_case == 5) {
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);
if (context.is_stateful()) {
std::vector<int64_t> shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec);
} else {
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);
}
// // Alternative
// auto token_len = context.get_input("token_len");
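
A worked example of the two op_case 1 targets above, for a hypothetical ggml output shape [1, 32, 7, 128]; both resolve -1 to the token count at runtime, only the rank differs.

// stateless target: {1, -1, 7, 128}  (4-D, leading batch kept)
auto shape_4d = ov::op::v0::Constant::create(ov::element::i64, {4}, std::vector<int64_t>{1, -1, 7, 128});
// stateful target:  {-1, 7, 128}     (3-D, leading 1 dropped)
auto shape_3d = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, 7, 128});
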

View File

@@ -54,9 +54,18 @@ OutputVector translate_rope(const NodeContext & context) {
// The input comes from a VIEW
int slice_len = output_shape[2] * output_shape[3];
data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr();
auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
if (context.is_stateful()) {
auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
} else {
auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
}
//auto data_shape = ov::op::v0::Constant::create(
// ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
//data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
}
const int mode = op_params[2];
@@ -67,10 +76,19 @@ OutputVector translate_rope(const NodeContext & context) {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
auto even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
auto odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
Output<Node> even_slice;
Output<Node> odd_slice;
int32_t unsqueeze_dim = 4;
if (context.is_stateful()) {
unsqueeze_dim = 3;
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
} else {
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
}
Output<Node> first_half =
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
@@ -80,10 +98,10 @@ OutputVector translate_rope(const NodeContext & context) {
std::make_shared<ov::op::v1::Multiply>(odd_slice, cos_theta_node));
first_half = std::make_shared<ov::op::v0::Unsqueeze>(first_half,
ov::op::v0::Constant::create(ov::element::i64, {1}, {4}));
ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));
second_half = std::make_shared<ov::op::v0::Unsqueeze>(second_half,
ov::op::v0::Constant::create(ov::element::i64, {1}, {4}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 4);
ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, unsqueeze_dim);
auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
@@ -102,7 +120,11 @@ OutputVector translate_rope(const NodeContext & context) {
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, 3);
int32_t concat_dim = 3;
if (context.is_stateful()) {
concat_dim = 2;
}
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
}
return rename_outputs_with_suffix({res}, context.get_name());
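
The stateful branches above all follow one rule: with the leading batch dimension gone, every axis index shifts down by one. In summary (a reading aid, not commit code):

// layout               stateless (4-D)   stateful (3-D)
// slice axis           3                 2
// unsqueeze dim        4                 3
// stack/concat axis    4                 3
// final rope concat    3                 2
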

View File

@@ -45,7 +45,17 @@ OutputVector translate_set_rows(const NodeContext & context) {
false);
auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
Output<Node> res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);
Output<Node> res;
if (context.is_stateful()) {
int concat_axis = 1;
int64_t dim2 = dst.get_partial_shape()[2].get_length();
int64_t dim3 = dst.get_partial_shape()[3].get_length();
data = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);
} else {
res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);
}
if (auto dst_reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(dst.get_node_shared_ptr())) {
// Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb]
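
For context, a minimal sketch of the stateful KV write above (shapes hypothetical): instead of scattering rows into a preallocated cache, the new K/V rows are appended along the token axis; MakeStateful later turns the concatenated result into a state assignment.

// dst:  [1, past_len, n_heads_kv, head_size]  (read back from state)
// data: [1, new_len,  n_heads_kv, head_size]  (rows for the current tokens)
auto appended = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{dst, data}, 1);
// appended: [1, past_len + new_len, n_heads_kv, head_size]
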

View File

@@ -18,6 +18,7 @@
#include <openvino/op/convert.hpp>
#include <openvino/op/cos.hpp>
#include <openvino/op/divide.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/range.hpp>
@@ -82,6 +83,20 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
std::shared_ptr<ov::Node> mask_sliced;
if (is_static) {
mask_sliced = mask;
} else if (ggml_model_decoder.is_stateful()) {
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2, -1});
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len_per_seq, gather_inp_pos}, 0);
mask_sliced =
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
mask_sliced->set_friendly_name(sliced_name);
} else {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
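
A shape walk-through of the stateful mask slice above (sizes hypothetical): the padded mask is cropped on its last two axes to the current token count and the number of positions seen so far.

// mask: [1, 1, mask_rows, mask_cols], e.g. [1, 1, 64, 512]
// stop = Concat(token_len_per_seq, len(inp_pos)), e.g. [7, 39]
// Slice(starts = {0, 0}, stop, steps = {1, 1}, axes = {-2, -1}) -> [1, 1, 7, 39]
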
@@ -226,11 +241,11 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
manager.set_per_pass_validation(true);
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
// if (!ggml_model_decoder->is_static()) {
// const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
// const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
// manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
// }
if (ggml_model_decoder->is_stateful()) {
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
}
if (ggml_model_decoder->is_static()) {
manager.register_pass<pass::EliminateZeroPoints>();

View File

@@ -15,6 +15,7 @@
#include <openvino/op/multiply.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/sin.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/subtract.hpp>
#include <openvino/op/transpose.hpp>
#include <string>
@@ -113,11 +114,20 @@ void ggml_rope_yarn_corr_dims(int n_dims,
std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params,
std::shared_ptr<ov::Node> inp_pos,
std::shared_ptr<ov::Node> rope_freqs_weight) {
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
auto pos_perm =
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
std::shared_ptr<ov::Node> rope_freqs_weight,
bool stateful) {
if (stateful) {
inp_pos = std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
auto pos_perm =
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0});
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
} else {
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
auto pos_perm =
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
}
float freq_base;
float freq_scale;
@@ -145,8 +155,14 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
factor[i] = theta_scale * factor[i - 1];
}
Output<Node> freq_factors =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
Output<Node> freq_factors;
if (stateful) {
freq_factors =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);
} else {
freq_factors =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
}
if (rope_freqs_weight) {
freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight);
}
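
The same rank rule as elsewhere in this commit applies here: a constant that broadcasts against the rotated data must match its rank. A sketch with a hypothetical half-dimension of 64 (values are placeholders):

std::vector<float> factor(64, 1.0f);  // placeholder frequency factors
// stateless (4-D data): Shape{1, 1, 1, 64}; stateful (3-D data): Shape{1, 1, 64}
auto ff = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1, 1, 64}, factor);
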
@@ -161,7 +177,12 @@ ...
theta = theta_interp;
} else {
auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor);
auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
Output<Node> one;
if (stateful) {
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});
} else {
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
}
auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);
theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),

View File

@@ -66,7 +66,8 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::
std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t* rope_params,
std::shared_ptr<ov::Node> inp_pos,
std::shared_ptr<ov::Node> rope_freqs_weight = nullptr);
std::shared_ptr<ov::Node> rope_freqs_weight = nullptr,
bool stateful = false);
ov::Output<ov::Node> process_view_input(const NodeContext& context, int input_index, int slice_len = 0);

View File

@@ -46,10 +46,14 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) {
// Use device from singleton (initialized during backend init)
const auto & device = ggml_openvino_get_device_name();
const auto is_static = ggml_openvino_is_npu();
return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device);
bool stateful = false;
if (getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !is_static) {
stateful = true;
}
return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful);
}
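
The stateful path is opt-in: it is taken only when GGML_OPENVINO_STATEFUL_EXECUTION is set in the environment and the device is not static (NPU). A hypothetical way to enable it from a POSIX test harness, before the backend runs its first graph:

#include <cstdlib>
// Must run before the first ov_graph_compute() call so the flag is seen.
setenv("GGML_OPENVINO_STATEFUL_EXECUTION", "1", /*overwrite=*/1);
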
enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device) {
enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device, bool stateful) {
auto & core = ov_singleton_core();
const auto & config = ggml_openvino_get_compile_config();
static auto is_static = false;
@@ -99,6 +103,12 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
ggml_decoder->add_extra_inputs();
infer_request = infer_request_cache[key];
auto * inp_pos = get_inp_pos_tensor(cgraph);
int32_t * pos_data = (int32_t *) inp_pos->data;
if (pos_data[0] == 0) {
infer_request->reset_state();
}
decoder_end_time = ggml_time_us();
conversion_end_time = decoder_end_time;
compile_end_time = decoder_end_time;
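
The reset above keys off the first position id: position 0 marks the start of a new sequence, so the KV tensors held as model state must be cleared. An equivalent per-variable form using the ov::InferRequest state API (a sketch; the commit uses the blanket reset_state()):

for (auto & state : infer_request->query_state()) {
    state.reset();  // reinitialize each cache_k_* / cache_v_* variable
}
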
@@ -108,7 +118,7 @@ ...
std::shared_ptr<ov::Model> model;
auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static);
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static, stateful);
decoder_end_time = ggml_time_us();
auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
@@ -202,6 +212,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
static std::string device = "NPU";
static auto is_static = true;
static auto stateful = false;
static auto prefill_chunk_size = get_prefill_chunk_size();
const auto & config = ggml_openvino_get_compile_config();
@@ -265,9 +276,9 @@ ...
auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
is_static, true, prefill_chunk_size);
is_static, stateful, true, prefill_chunk_size);
auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
is_static, false, prefill_chunk_size);
is_static, stateful, false, prefill_chunk_size);
decoder_end_time = ggml_time_us();
auto input_model_prefill = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_prefill);
@@ -606,8 +617,17 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, con
if (ggml_decoder->is_static() && result_name == "result_output" && output_shape[2] == 0) {
output_shape[2] = 1;
}
ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
return output_tensor;
if (ggml_decoder->is_stateful() && result_name == "result_output") {
std::vector<size_t> output_shape_3d;
for (size_t i = 1; i < output_shape.size(); i++) {
output_shape_3d.push_back(output_shape[i]);
}
ov::Tensor output_tensor(output_type, output_shape_3d, ggml_tensor->data);
return output_tensor;
} else {
ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
return output_tensor;
}
}
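
A worked example of the stateful output handling above (sizes hypothetical): the stateful graph emits 3-D logits, so the leading dimension of the 4-D ggml shape is dropped before wrapping the ggml buffer.

ov::Shape shape_4d{1, 1, 7, 32000};  // [1, 1, n_tokens, n_vocab]
ov::Shape shape_3d(shape_4d.begin() + 1, shape_4d.end());  // {1, 7, 32000}
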
size_t checksum(const void * data, size_t size) {

View File

@@ -28,7 +28,7 @@ struct graph_key_hash {
enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph);
enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, const std::string & device);
enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, const std::string & device, bool stateful = false);
enum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph);
size_t checksum(const void * data, size_t size);