NPU Unify PD (#14)

* Stateless. Fix llama-cli llama-server

* Simplify broadcast op in attention

* Replace get_output_tensor+memcpy with set_output_tensor

* NPU unify PD. Unify dynamic and static dims
This commit is contained in:
Zijun Yu 2025-11-04 15:19:09 +08:00 committed by Mustafa Cavus
parent eba8113dc4
commit b8690bc055
11 changed files with 227 additions and 370 deletions

View File

@ -27,7 +27,6 @@
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/runtime/tensor.hpp>
#include <optional>
#include <ostream>
@ -39,7 +38,6 @@
GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
ggml_cgraph * cgraph,
bool is_static,
bool is_first_token,
int context_size,
int context_size_swa,
int num_heads,
@ -55,25 +53,24 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
m_num_heads(num_heads),
m_num_heads_kv(num_heads_kv),
m_head_size(head_size),
m_is_static(is_static),
m_is_first_token(is_first_token) {
m_is_static(is_static) {
set_input_output(node);
}
GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static,
bool is_first_token) :
bool is_static) :
m_cgraph(cgraph),
m_op_name(m_node ? std::string(m_node->name) : ""),
m_model_weights(model_weights),
m_is_static(is_static),
m_is_first_token(is_first_token) {
if (is_first_token && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) {
m_is_static(is_static) {
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
print_tensor_address_map(cgraph);
}
set_llm_params();
validate_cgraph();
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
auto * cur_node = cgraph->nodes[node_n];
@ -160,8 +157,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
// Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches
static std::set<std::string> debug_output_names = {};
// Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph
if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT || node_name.find("result") == 0 ||
debug_output_names.count(node_name)) {
if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT ||
node_name.find("output") != std::string::npos || debug_output_names.count(node_name)) {
if (node->op == GGML_OP_SET_ROWS) {
assert(node_name.find("cache_k") == 0 || node_name.find("cache_v") == 0);
if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), node_name); it == m_kv_names.end()) {
@ -285,53 +282,54 @@ void GgmlOvDecoder::set_llm_params() {
} else {
m_context_size = cache_k->ne[1];
}
} else if (node->op == GGML_OP_ROPE &&
(name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) {
m_head_size = node->ne[0];
m_num_heads = node->ne[1];
m_rope_params = node->op_params;
} else if (node->op == GGML_OP_ROPE &&
(name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0)) {
m_num_heads_kv = node->ne[1];
} else if (node->op == GGML_OP_ROPE) {
if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
m_head_size = node->ne[0];
m_num_heads = node->ne[1];
m_rope_params = node->op_params;
auto * inp_pos = node->src[1];
m_input_len = inp_pos->ne[0];
m_past_kv_len = *(int32_t *) inp_pos->data;
} else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
m_num_heads_kv = node->ne[1];
}
}
}
}
void GgmlOvDecoder::validate_cgraph() const {
if (m_is_static && m_input_len != 1) {
throw std::runtime_error("Static graph (NPU) must have input_len == 1, but got " + std::to_string(m_input_len) +
", try set -ub 1");
}
}
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * src) const {
auto name = std::string(src->name);
ov::PartialShape input_shape;
if (name == "inp_tokens" || name == "inp_pos") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{1, 1, m_context_size};
} else {
input_shape = ov::PartialShape{1, 1, 1};
}
} else {
input_shape = ov::PartialShape{1, 1, -1};
}
} else if (name == "inp_out_ids" && !m_is_static) {
input_shape = ov::PartialShape{1, 1, -1};
if (name == "inp_tokens" || name == "inp_pos" || name == "inp_out_ids") {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
} else if (name.find("KQ_mask") == 0) {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{1, m_context_size, m_context_size};
} else {
input_shape = ov::PartialShape{1, 1, m_context_size};
}
input_shape = ov::PartialShape{1, 1, m_context_size};
} else {
input_shape = ov::PartialShape{1, -1, -1};
}
} else if (name.find("cache_") == 0) {
auto past_token_len = -1;
if (m_is_static) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
} else {
input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
past_token_len = is_swa ? m_context_size_swa : m_context_size;
}
input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
} else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
} else if (src->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ov::PartialShape{get_shape(src->view_src)};
@ -745,9 +743,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto & node : m_nodes) {
auto decoder =
std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token, m_context_size,
m_context_size_swa, m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
node_visitor(decoder);
}
}

View File

@ -16,14 +16,12 @@ public:
// Graph decoder
GgmlOvDecoder(ggml_cgraph * cgraph,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static,
bool is_first_token);
bool is_static);
// Node decoder, called in GgmlOvDecoder::visit_subgraph
GgmlOvDecoder(ggml_tensor * node,
ggml_cgraph * cgraph,
bool is_static,
bool is_first_token,
int context_size,
int context_size_swa,
int num_heads,
@ -81,9 +79,9 @@ public:
virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const override;
const ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }
ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }
const ggml_tensor * get_output_ggml_tensor(const std::string & name) const { return m_outputs.at(name); }
ggml_tensor * get_output_ggml_tensor(const std::string & name) const { return m_outputs.at(name); }
virtual int get_op_case() const override { return m_op_case; }
@ -119,14 +117,16 @@ public:
virtual int get_head_size() const override { return m_head_size; }
int get_past_kv_len() const { return m_past_kv_len; }
int get_input_len() const { return m_input_len; }
virtual int32_t * get_rope_params() const override { return m_rope_params; }
virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
virtual bool is_static() const override { return m_is_static; }
virtual bool is_first_token() const override { return m_is_first_token; }
ov::PartialShape get_graph_input_shape(const ggml_tensor * src) const;
static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
@ -153,6 +153,7 @@ private:
// set context_size, num_heads, etc
void set_llm_params();
void validate_cgraph() const;
ggml_cgraph * m_cgraph = nullptr;
ggml_tensor * m_node = nullptr;
@ -176,10 +177,11 @@ private:
int m_num_heads;
int m_num_heads_kv;
int m_head_size;
int m_past_kv_len;
int m_input_len;
int32_t * m_rope_params;
std::vector<std::string> m_kv_names;
bool m_is_static = false;
bool m_is_first_token;
};
void print_tensor_address_map(const ggml_cgraph * cgraph);

View File

@ -65,7 +65,6 @@ public:
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
virtual bool is_static() const = 0;
virtual bool is_first_token() const = 0;
virtual int get_context_size() const = 0;
virtual int get_context_size_swa() const = 0;
virtual int is_swa_layer(int layer) const = 0;

View File

@ -97,12 +97,7 @@ public:
int get_op_case() const {
return m_decoder->get_op_case();
}
bool is_static() const {
return m_decoder->is_static();
}
bool is_first_token() const {
return m_decoder->is_first_token();
}
bool is_static() const { return m_decoder->is_static(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder;

View File

@ -2,9 +2,11 @@
#include "../op_table.hpp"
#include "../utils.hpp"
#include <cstdint>
#include <memory>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
@ -51,43 +53,25 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
}
if (mask_sliced.get_element_type() != ov::element::f16) {
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
}
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) {
int64_t factor = q_batch / kv_batch;
auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv) {
int64_t factor = num_heads / num_heads_kv;
if (factor > 1) {
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
if (is_static) {
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
} else {
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
kv_broadcast_shape =
ov::op::v0::Constant::create(ov::element::i64, {4}, {num_heads_kv, factor, (int64_t) 1, head_size});
new_kv_shape = ov::op::v0::Constant::create(ov::element::i64, {3}, {num_heads, (int64_t) -1, head_size});
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
}
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,
ov::op::BroadcastType::BIDIRECTIONAL);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
}
return kv;
@ -95,18 +79,12 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
auto q_shape = context.get_input_shape(0).to_shape();
auto k_shape = context.get_input_shape(1).to_shape();
k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());
k = tile_kv(q_shape[0], k_shape[0], q_shape[2], k);
v = tile_kv(q_shape[0], k_shape[0], q_shape[2], v);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
if (context.is_static()) {
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
res = std::make_shared<ov::op::v1::Transpose>(
sdpa_f32, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
res = std::make_shared<ov::op::v1::Transpose>(sdpa, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -26,40 +26,8 @@ OutputVector translate_permute(const NodeContext & context) {
ov::Output<Node> res;
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
if (op_case == 1) {
if (context.is_static()) {
res = std::make_shared<ov::op::v1::Transpose>(
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
auto src = context.get_input(0);
if (src.get_partial_shape().rank() == 3) {
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
}
res = std::make_shared<ov::op::v1::Transpose>(
src, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
} else {
auto src = context.get_input(0);
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
if (context.is_static()) {
auto src_shape_ = context.get_input_shape(0).to_shape();
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3},
std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);
res = std::make_shared<ov::op::v1::Transpose>(
src_reshaped, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
if (src.get_partial_shape().rank() == 3) {
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
}
res = std::make_shared<ov::op::v1::Transpose>(
src, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
}
auto src = context.get_input(0);
res = std::make_shared<ov::op::v1::Transpose>(src, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -84,10 +84,6 @@ OutputVector translate_rope(const NodeContext & context) {
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v0::ShapeOf>(data_node), false);
if (!(context.is_static())) {
res =
std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
}
} else if (mode == ROPE_TYPE_NEOX) {
auto data_split = std::make_shared<ov::op::v1::Split>(
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);

View File

@ -33,10 +33,6 @@ OutputVector translate_set_rows(const NodeContext & context) {
auto dst_shape = context.get_output_shape(0).to_shape();
FRONT_END_OP_CONVERSION_CHECK(dst_shape[0] == 1, "Unsupported shape in SET_ROWS");
if (context.is_static() && context.is_first_token()) {
return rename_outputs_with_suffix({data}, context.get_name());
}
auto indices = context.get_input(1);
auto dst = context.get_input(context.get_output_name());
@ -54,13 +50,11 @@ OutputVector translate_set_rows(const NodeContext & context) {
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
} else {
assert(dst.get_partial_shape().rank() == 4 && dst.get_partial_shape()[2].is_static() &&
dst.get_partial_shape()[3].is_static());
int64_t dim1 = dst.get_partial_shape()[1].get_length();
int64_t dim2 = dst.get_partial_shape()[2].get_length();
int64_t dim3 = dst.get_partial_shape()[3].get_length();
data = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, 1);
data, ov::op::v0::Constant::create(ov::element::i64, {3}, {(int64_t) -1, dim1, dim2}), false);
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, 0);
}
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -27,7 +27,6 @@
#include <openvino/op/squeeze.hpp>
#include <openvino/op/strided_slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/make_stateful.hpp>
@ -112,7 +111,6 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
mask_sliced->set_friendly_name(sliced_name);
}
@ -243,11 +241,11 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
manager.set_per_pass_validation(true);
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
if (!ggml_model_decoder->is_static()) {
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
}
// if (!ggml_model_decoder->is_static()) {
// const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
// const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
// manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
// }
// if (ggml_model_decoder->is_static()) {
manager.register_pass<pass::EliminateZeroPoints>();

View File

@ -12,12 +12,14 @@
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <mutex>
#include <openvino/core/any.hpp>
#include <openvino/core/graph_util.hpp>
#include <openvino/core/shape.hpp>
#include <openvino/core/type/float16.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/openvino.hpp>
@ -26,60 +28,29 @@
#include <openvino/runtime/intel_npu/properties.hpp>
#include <openvino/runtime/properties.hpp>
#include <openvino/runtime/tensor.hpp>
#include <string>
#include <unordered_map>
#include <vector>
ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & name) {
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name);
auto * input_data = ggml_tensor->data;
ov::Shape input_shape;
if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor).to_shape();
} else if (ggml_tensor->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor->view_src).to_shape();
} else {
input_shape = ggml_decoder->get_input_shape(name).to_shape();
}
auto input_tensor = ov::Tensor(ggml_decoder->get_input_type(name), input_shape, input_data);
return input_tensor;
}
std::map<std::string, void *> get_ggml_graph_output_dst(std::shared_ptr<GgmlOvDecoder> ggml_decoder) {
std::map<std::string, void *> output_tensors;
auto output_names = ggml_decoder->get_model_output_names();
for (size_t inp = 0; inp < output_names.size(); ++inp) {
auto name = output_names[inp];
const auto * tensor = ggml_decoder->get_output_ggml_tensor(name);
auto * output_data = tensor->view_src ? tensor->view_src->data : tensor->data;
output_tensors[name] = output_data;
}
return output_tensors;
}
static ov::frontend::FrontEnd::Ptr get_ggml_frontend() {
auto fem = ov::frontend::FrontEndManager();
auto front_end = fem.load_by_framework("ggml");
return front_end;
}
// Suppress deprecation warning for ov::Tensor::data()
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
static ov::Core core;
static std::string device = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : "";
if (device.empty()) {
const std::vector<std::string> preferred_device = {"GPU", "CPU", "NPU"};
const auto available_devices = core.get_available_devices();
for (const auto & dev : preferred_device) {
if (std::find(available_devices.begin(), available_devices.end(), dev) != available_devices.end()) {
device = dev;
break;
}
auto get_device = [&] {
std::string device = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : "CPU";
auto available_devices = core.get_available_devices();
if (std::find(available_devices.begin(), available_devices.end(), device) == available_devices.end()) {
GGML_LOG_WARN("GGML OpenVINO Backend: device %s is not available, fallback to CPU\n", device.c_str());
device = "CPU";
}
}
return device;
};
static std::string device = get_device();
bool is_static = device == "NPU" ? true : false;
ov::AnyMap config;
if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
@ -102,11 +73,9 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
static std::unordered_map<ggml_cgraph *, std::shared_ptr<ov::InferRequest>> infer_request_cache;
static std::unordered_map<ggml_cgraph *, std::vector<std::string>> ov_input_names_cache;
static std::unordered_map<ggml_cgraph *, std::vector<std::string>> ov_output_names_cache;
// For NPU, store the kvcache model, since we cannot create two infer_request
static std::unordered_map<ggml_cgraph *, ov::CompiledModel> compiled_model_cache;
std::shared_ptr<GgmlOvDecoder> ggml_decoder;
ov::InferRequest infer_request;
std::shared_ptr<ov::InferRequest> infer_request;
int64_t decoder_end_time;
int64_t conversion_end_time;
@ -118,83 +87,36 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
auto it = infer_request_cache.find(cgraph);
if (it != infer_request_cache.end()) {
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false);
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static);
decoder_end_time = ggml_time_us();
// For NPU for the first time we call kvcache modle, pop the compiled kvcache model from cache
if (is_static && compiled_model_cache.find(cgraph) != compiled_model_cache.end()) {
infer_request_cache[cgraph] =
std::make_shared<ov::InferRequest>(compiled_model_cache[cgraph].create_infer_request());
compiled_model_cache.erase(cgraph);
}
infer_request = *infer_request_cache[cgraph];
infer_request = infer_request_cache[cgraph];
conversion_end_time = ggml_time_us();
compile_end_time = conversion_end_time;
} else {
std::shared_ptr<ov::Model> model;
auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, get_types_to_requant(device));
if (is_static) {
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true);
auto ggml_decoder_kvcache = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false);
decoder_end_time = ggml_time_us();
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static);
decoder_end_time = ggml_time_us();
auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
auto input_model_kvcache = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_kvcache);
auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
model = ov::frontend::ggml::FrontEnd::convert(input_model);
ggml_decoder->clear_model_weights();
conversion_end_time = ggml_time_us();
model = ov::frontend::ggml::FrontEnd::convert(input_model);
ggml_decoder->clear_model_weights();
auto model_kvcache = ov::frontend::ggml::FrontEnd::convert(input_model_kvcache);
ggml_decoder_kvcache->clear_model_weights();
conversion_end_time = ggml_time_us();
if (getenv("GGML_OPENVINO_DUMP_IR")) {
char timestamped_filename[64];
auto timestamp = (long long) ggml_time_us();
snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp);
ov::serialize(model, timestamped_filename);
snprintf(timestamped_filename, sizeof(timestamped_filename), "model_kvcache_%lld.xml", timestamp);
ov::serialize(model_kvcache, timestamped_filename);
}
auto compiled_model = core.compile_model(model, device, get_npu_prefill_config());
auto compiled_model_kvcache = core.compile_model(model_kvcache, device, get_npu_generate_config());
compiled_model_cache[cgraph] = compiled_model_kvcache;
compile_end_time = ggml_time_us();
infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
infer_request = *infer_request_cache[cgraph];
compiled_model_cache[cgraph] = compiled_model_kvcache;
} else {
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true);
decoder_end_time = ggml_time_us();
auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
model = ov::frontend::ggml::FrontEnd::convert(input_model);
ggml_decoder->clear_model_weights();
conversion_end_time = ggml_time_us();
if (getenv("GGML_OPENVINO_DUMP_IR")) {
char timestamped_filename[64];
auto timestamp = (long long) ggml_time_us();
snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp);
ov::serialize(model, timestamped_filename);
}
auto * disable_sdpa_optimization = getenv("GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION");
if (disable_sdpa_optimization && std::string(disable_sdpa_optimization) != "0") {
config = {
{"GPU_ENABLE_SDPA_OPTIMIZATION", "0"}
};
}
auto compiled_model = core.compile_model(model, device, config);
compile_end_time = ggml_time_us();
infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
infer_request = *infer_request_cache[cgraph];
if (getenv("GGML_OPENVINO_DUMP_IR")) {
char timestamped_filename[64];
auto timestamp = (long long) ggml_time_us();
snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp);
ov::serialize(model, timestamped_filename);
}
auto compiled_model = core.compile_model(model, device, get_ov_compile_config(device));
compile_end_time = ggml_time_us();
infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
infer_request = infer_request_cache[cgraph];
std::vector<std::string> ov_input_names;
std::vector<std::string> ov_output_names;
for (const auto & ov_param : model->get_parameters()) {
@ -210,72 +132,66 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
auto ov_input_names = ov_input_names_cache[cgraph];
auto ov_output_names = ov_output_names_cache[cgraph];
for (size_t i = 0; i < ov_input_names.size(); i++) {
auto param_name = ov_input_names[i];
auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name);
infer_request.set_input_tensor(i, input_tensor);
infer_request->set_input_tensor(i, input_tensor);
if (getenv("GGML_OPENVINO_DEBUG_INPUT")) {
print_input_tensor_info(param_name, input_tensor);
}
}
for (size_t i = 0; i < ov_output_names.size(); i++) {
auto output_tensor = get_ov_output_tensor(ggml_decoder, ov_output_names[i]);
infer_request->set_output_tensor(i, output_tensor);
}
auto input_end_time = ggml_time_us();
infer_request.infer();
infer_request->infer();
auto infer_end_time = ggml_time_us();
auto gguf_tensor_addrs = get_ggml_graph_output_dst(ggml_decoder);
for (size_t i = 0; i < ov_output_names.size(); i++) {
auto & result_name = ov_output_names[i];
const auto output_tensor = infer_request.get_output_tensor(i);
std::memcpy(gguf_tensor_addrs[result_name], output_tensor.data(), output_tensor.get_byte_size());
const auto output_tensor = infer_request->get_output_tensor(i);
if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) {
print_output_tensor_info(result_name, output_tensor, gguf_tensor_addrs);
print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());
}
}
auto end_time = ggml_time_us();
if (getenv("GGML_OPENVINO_PROFILING")) {
GGML_LOG_INFO("GGML OpenVINO Backend: \n");
GGML_LOG_INFO("\nGGML OpenVINO Backend: \n");
GGML_LOG_INFO(" - Graph decoder Time: %ld ms \n", (decoder_end_time - start_time) / 1000);
GGML_LOG_INFO(" - Graph conversion Time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000);
GGML_LOG_INFO(" - Graph compile Time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000);
GGML_LOG_INFO(" - Graph Input Time: %ld ms \n", (input_end_time - compile_end_time) / 1000);
GGML_LOG_INFO(" - Graph Inference Time: %ld ms \n", (infer_end_time - input_end_time) / 1000);
GGML_LOG_INFO(" - Graph Output Time: %ld ms \n", (end_time - infer_end_time) / 1000);
}
return GGML_STATUS_SUCCESS;
GGML_UNUSED(backend);
}
namespace {
ov::AnyMap get_npu_base_config() {
return {
{"NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add_RMSNorm" },
{"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" },
{"NPU_USE_NPUW", "YES" },
{"NPUW_DEVICES", "NPU" },
{"NPUW_FOLD", "YES" },
{"NPUW_WEIGHTS_BANK", "shared" },
{"NPUW_FUNCALL_FOR_ALL", "YES" },
{"NPUW_FUNCALL_ASYNC", "YES" },
{"NPUW_DQ", "YES" },
{"NPUW_DQ_FULL", "NO" },
{"NPUW_CACHE_DIR", getenv("GGML_OPENVINO_CACHE_DIR") ? getenv("GGML_OPENVINO_CACHE_DIR") : ""},
};
}
} // namespace
ov::AnyMap get_npu_prefill_config() {
auto config = get_npu_base_config();
return config;
}
ov::AnyMap get_npu_generate_config() {
auto config = get_npu_base_config();
ov::AnyMap get_ov_compile_config(const std::string & device) {
ov::AnyMap config;
if (device == "NPU") {
config = {
{"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" },
{"NPU_USE_NPUW", "YES" },
{"NPUW_DEVICES", "NPU" },
{"NPUW_FOLD", "YES" },
{"NPUW_WEIGHTS_BANK", "shared"},
{"NPUW_FUNCALL_FOR_ALL", "YES" },
{"NPUW_FUNCALL_ASYNC", "YES" },
{"NPUW_DQ", "YES" },
{"NPUW_DQ_FULL", "NO" },
};
if (auto * cache_dir = getenv("GGML_OPENVINO_CACHE_DIR"); cache_dir) {
config["NPUW_CACHE_DIR"] = cache_dir;
}
}
return config;
}
@ -291,7 +207,7 @@ std::map<ggml_type, ExtraQuantType> get_types_to_requant(const std::string & dev
}
if (device == "GPU") {
return {
// gs16 is WIP
// gs16 will be supported on openvino-2025.4
{GGML_TYPE_Q6_K, ExtraQuantType::Q8_0_32},
};
}
@ -331,70 +247,91 @@ enum ggml_status naive_compute(ggml_cgraph * cgraph,
infer_request.set_input_tensor(i, input_tensor);
}
infer_request.infer();
auto gguf_tensor_addrs = get_ggml_graph_output_dst(decoder);
auto ov_results = model->get_results();
for (size_t i = 0; i < ov_results.size(); i++) {
auto result_name = ov_results[i]->get_friendly_name();
const auto output_tensor = infer_request.get_output_tensor(i);
std::memcpy(gguf_tensor_addrs[result_name], output_tensor.data(), output_tensor.get_byte_size());
auto output_tensor = get_ov_output_tensor(decoder, result_name);
infer_request.set_output_tensor(i, output_tensor);
}
infer_request.infer();
return GGML_STATUS_SUCCESS;
}
namespace {
ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & name) {
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name);
auto * input_data = ggml_tensor->data;
ov::Shape input_shape;
if (ggml_tensor->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor->view_src).to_shape();
} else {
input_shape = ggml_decoder->get_input_shape(name).to_shape();
}
auto input_tensor = ov::Tensor(ggml_decoder->get_input_type(name), input_shape, input_data);
return input_tensor;
}
} // namespace
ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name) {
bool is_static = ggml_decoder->is_static();
bool is_first_token = ggml_decoder->is_first_token();
ov::Tensor input_tensor;
if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) {
input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name);
} else if (!is_static) {
} else if (param_name.find("cache_k") == 0 || param_name.find("cache_v") == 0) {
void * input_data = ggml_decoder->get_input_ggml_tensor(param_name)->data;
size_t past_kv_len =
ggml_decoder->is_static() ? ggml_decoder->get_context_size() : ggml_decoder->get_past_kv_len();
ov::Shape input_shape = {past_kv_len, (size_t) ggml_decoder->get_num_heads_kv(),
(size_t) ggml_decoder->get_head_size()};
input_tensor = ov::Tensor(ggml_decoder->get_input_type(param_name), input_shape, input_data);
} else if (is_static && param_name.find("KQ_mask") == 0) {
size_t context_size = ggml_decoder->get_context_size();
const auto * input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, 1, context_size, -INFINITY);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, 1, context_size});
auto * data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else if (is_static && param_name.find("inp_out_ids") == 0) {
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
if (input_tensor.get_size() == 0) {
input_tensor = ov::Tensor(input_tensor.get_element_type(), ov::Shape{1, 1, 1});
*input_tensor.data<int32_t>() = 0;
}
} else {
if (param_name == "inp_tokens" || param_name == "inp_pos") {
if (is_first_token) {
size_t context_size = ggml_decoder->get_context_size();
const auto * input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
std::vector<int32_t> padded_data = pad_input<int32_t>(input_tensor_ggml, 1, context_size, 0);
input_tensor = ov::Tensor(ov::element::i32, ov::Shape{1, 1, context_size});
auto * data_ptr = input_tensor.data<int32_t>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else {
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
}
} else if (param_name.find("KQ_mask") == 0) {
size_t context_size = ggml_decoder->get_context_size();
const auto * input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
if (is_first_token) {
std::vector<float> padded_data =
pad_input<float>(input_tensor_ggml, context_size, context_size, -INFINITY);
set_zero_diagonal(padded_data, context_size);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, context_size, context_size});
auto * data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else {
std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, 1, context_size, -INFINITY);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, 1, context_size});
auto * data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
}
} else if (const auto * op = ggml_decoder->get_tensor_used_op(ggml_decoder->get_tensor_from_name(param_name));
op && op->op == GGML_OP_SET_ROWS && is_static && is_first_token) {
input_tensor = ov::Tensor(ov::element::i64, ov::Shape{1, 1, 1});
} else {
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
}
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
}
return input_tensor;
}
ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & result_name) {
auto * ggml_tensor = ggml_decoder->get_output_ggml_tensor(result_name);
auto output_type = ggml_decoder->get_output_type(result_name);
ov::Shape output_shape;
if (result_name.find("cache") == std::string::npos) {
output_shape = ggml_decoder->get_output_shape(result_name).to_shape();
if (ggml_decoder->is_static() && result_name == "result_output") {
output_shape[1] = 1;
}
} else {
size_t total_token_len = ggml_decoder->get_past_kv_len() + ggml_decoder->get_input_len();
size_t num_heads_kv = ggml_decoder->get_num_heads_kv();
size_t head_size = ggml_decoder->get_head_size();
if (ggml_decoder->is_static()) {
total_token_len = ggml_decoder->get_context_size();
}
output_shape = ov::Shape{total_token_len, num_heads_kv, head_size};
}
ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
return output_tensor;
}
size_t checksum(const void * data, size_t size) {
const uint8_t * bytes = static_cast<const uint8_t *>(data);
size_t sum = 0;
@ -405,10 +342,6 @@ size_t checksum(const void * data, size_t size) {
return sum;
}
// Suppress deprecation warning for ov::Tensor::data()
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor) {
std::cout << "Input name: " << name << ", Input shape: " << tensor.get_shape() << ", Address: " << tensor.data()
<< std::endl;
@ -433,11 +366,9 @@ void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor
}
}
void print_output_tensor_info(const std::string & name,
const ov::Tensor & tensor,
std::map<std::string, void *> & output_dst) {
std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape()
<< ", Address: " << output_dst[name] << std::endl;
void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, void * output_dst) {
std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape() << ", Address: " << output_dst
<< std::endl;
auto print_float_stats = [](const std::string & type_name, size_t size, auto get_value) {
if (size == 0) {
@ -485,15 +416,13 @@ void print_output_tensor_info(const std::string & name,
}
}
#pragma GCC diagnostic pop
void set_zero_diagonal(std::vector<float> & matrix, size_t dim) {
for (size_t i = 0; i < dim; ++i) {
matrix[i * dim + i] = 0.0f;
}
}
bool is_prefill(ggml_cgraph * cgraph) {
const ggml_tensor * get_inp_pos_tensor(ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_nodes; ++i) {
auto * op = cgraph->nodes[i];
for (int j = 0; j < GGML_MAX_SRC; ++j) {
@ -501,11 +430,17 @@ bool is_prefill(ggml_cgraph * cgraph) {
if (src == nullptr) {
break;
}
if (std::string(src->name) == "inp_tokens") {
return src->ne[0] != 1;
if (std::string(src->name) == "inp_pos") {
return src;
}
}
}
GGML_LOG_ERROR("is_prefill: inp_tokens not found in cgraph");
throw std::runtime_error("is_prefill: inp_tokens not found in cgraph");
GGML_LOG_ERROR("get_inp_pos_tensor: inp_pos not found in cgraph");
throw std::runtime_error("get_inp_pos_tensor: inp_pos not found in cgraph");
}
bool get_is_first_token(const ggml_tensor * inp_pos) {
return *(int32_t *) inp_pos->data == 0;
}
#pragma GCC diagnostic pop

View File

@ -7,19 +7,11 @@
enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph);
std::shared_ptr<GgmlOvDecoder> get_ggml_decoder(struct ggml_cgraph * cgraph, bool is_static, bool is_first_token);
ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & name);
std::map<std::string, void *> get_ggml_graph_output_dst(std::shared_ptr<GgmlOvDecoder> ggml_decoder);
size_t checksum(const void * data, size_t size);
void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor);
void print_output_tensor_info(const std::string & name,
const ov::Tensor & tensor,
std::map<std::string, void *> & output_dst);
void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, void * output_dst);
template <typename T>
std::vector<T> pad_input(const ggml_tensor * tensor, size_t padded_rows, size_t padded_cols, T pad_value) {
@ -38,15 +30,18 @@ std::vector<T> pad_input(const ggml_tensor * tensor, size_t padded_rows, size_t
void set_zero_diagonal(std::vector<float> & matrix, size_t dim);
bool is_prefill(struct ggml_cgraph * cgraph);
const ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * cgraph);
ov::AnyMap get_npu_prefill_config();
ov::AnyMap get_npu_generate_config();
bool get_is_first_token(const ggml_tensor * inp_pos);
ov::AnyMap get_ov_compile_config(const std::string & device);
std::map<ggml_type, ExtraQuantType> get_types_to_requant(const std::string & device);
ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name);
ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & result_name);
bool is_naive(struct ggml_cgraph * cgraph);
enum ggml_status naive_compute(struct ggml_cgraph * cgraph,