Stateful transformation for CPU and GPU

Yu, Zijun 2025-06-26 13:54:06 +08:00 committed by Mustafa Cavus
parent 8afee795ad
commit 4c582ac7a3
7 changed files with 216 additions and 120 deletions
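In outline: on CPU and GPU the KV-cache tensors (cache_k*/cache_v*) are no longer copied in and out of the graph on every compute; matching Parameter/Result pairs are handed to ov::pass::MakeStateful so the runtime keeps them as internal state, and one InferRequest per cgraph is cached and reused. A minimal standalone sketch of that transformation, assuming a toy Add model and illustrative friendly names (not code from this commit):

#include <memory>
#include <openvino/op/add.hpp>
#include <openvino/openvino.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/manager.hpp>

int main() {
    // Toy stand-in for one KV slot: new_cache = cache + update.
    auto cache = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 4});
    cache->set_friendly_name("cache_k_l0");
    auto update = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 4});
    update->set_friendly_name("inp_update");  // illustrative name
    auto sum = std::make_shared<ov::op::v1::Add>(cache, update);
    auto res = std::make_shared<ov::op::v0::Result>(sum);
    res->set_friendly_name("cache_k_l0");  // same friendly name as the Parameter, as in get_kv_param_res_names()
    auto model = std::make_shared<ov::Model>(ov::ResultVector{res}, ov::ParameterVector{cache, update});

    // MakeStateful rewrites each (Parameter, Result) pair into ReadValue/Assign,
    // turning the tensor into internal state that survives across infer() calls.
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::MakeStateful>(ov::pass::MakeStateful::ParamResPairs{{cache, res}});
    manager.run_passes(model);

    ov::Core core;
    auto compiled = core.compile_model(model, "CPU");
    auto request = compiled.create_infer_request();  // only "inp_update" remains an input
    return 0;
}

After the pass runs, the cache parameter disappears from the compiled model's inputs; its value persists inside the infer request between calls, which is what lets the backend changes below drop the per-token cache round-trip.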

View File

@@ -26,12 +26,13 @@
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token)
: m_cgraph(cgraph),
m_node(node),
m_op_name(m_node ? std::string(m_node->name) : "NONE_OP"),
m_is_static(is_static),
m_is_first_token(is_first_token) {
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* m_cgraph, bool is_static,
bool is_first_token) :
m_cgraph(m_cgraph),
m_node(node),
m_op_name(m_node ? std::string(m_node->name) : "NONE_OP"),
m_is_static(is_static),
m_is_first_token(is_first_token) {
static std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
if (m_node) {
@@ -44,10 +45,11 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgrap
}
if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
dump_cgraph(m_cgraph);
std::string filename = "cgraph.txt";
dump_cgraph(m_cgraph, filename);
}
set_max_token_len();
set_llm_params();
static bool weight_created = false;
if (!weight_created) {
@@ -105,33 +107,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
continue;
}
ov::PartialShape input_shape;
if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{1, 1, m_max_token_len};
} else {
input_shape = ov::PartialShape{1, 1, 1};
}
} else {
input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_max_token_len)};
}
} else if (std::string(src->name) == "KQ_mask") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{1, m_max_token_len, m_max_token_len};
} else {
input_shape = ov::PartialShape{1, 1, m_max_token_len};
}
} else {
auto max_mask_size = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
input_shape =
ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
}
} else {
input_shape = ov::Shape{get_shape(src)};
}
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), input_shape);
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
param_node->set_friendly_name(src_name);
m_model_inputs[src_name] = param_node;
}
@@ -150,6 +126,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), name);
if (it == m_model_output_names.end()) {
m_model_output_names.push_back(name);
m_kv_names.push_back(name);
}
}
}
@@ -213,17 +190,54 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
}
}
void GgmlOvDecoder::set_max_token_len() {
void GgmlOvDecoder::set_llm_params() {
for (int i = 0; i < m_cgraph->n_nodes; i++) {
auto* node = m_cgraph->nodes[i];
if (std::string(node->name) == "cache_k_l0 (view)") {
if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
auto* cache_k = node->src[0];
m_max_token_len = cache_k->ne[1];
break;
} else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Qcur-0") {
m_head_size = node->ne[0];
m_num_heads = node->ne[1];
} else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Kcur-0") {
m_num_heads_kv = node->ne[1];
}
}
}
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) const {
ov::PartialShape input_shape;
if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{ 1, 1, m_max_token_len };
} else {
input_shape = ov::PartialShape{ 1, 1, 1 };
}
} else {
input_shape = ov::PartialShape{ 1, 1, ov::Dimension(1, m_max_token_len) };
}
} else if (std::string(src->name) == "KQ_mask") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{ 1, m_max_token_len, m_max_token_len };
} else {
input_shape = ov::PartialShape{ 1, 1, m_max_token_len };
}
} else {
auto max_mask_size = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
input_shape = ov::PartialShape{ 1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size) };
}
} else if (std::string(src->name).find("cache_k") == 0) {
input_shape = ov::PartialShape{ m_max_token_len, m_num_heads_kv, m_head_size };
} else if (std::string(src->name).find("cache_v") == 0) {
input_shape = ov::PartialShape{ m_num_heads_kv, m_head_size, m_max_token_len };
} else {
input_shape = ov::PartialShape{ get_shape(src) };
}
return input_shape;
}
void GgmlOvDecoder::add_extra_inputs() {
int64_t past_token_len = -1;
// attention_size not used for NPU
@@ -267,6 +281,16 @@ void GgmlOvDecoder::add_extra_inputs() {
}
}
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
std::map<std::string, std::string> kv_param_res_names;
for (const auto& name : m_kv_names) {
if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
kv_param_res_names[name] = name;
}
}
return kv_param_res_names;
}
void GgmlOvDecoder::add_weight_const_parallel(std::map<std::string, std::shared_ptr<ov::Node>>& model_weights) {
static std::mutex weights_mutex;
auto* nodes = m_cgraph->nodes;
@@ -344,8 +368,8 @@ std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor* tensor)
return weight_node;
}
void GgmlOvDecoder::dump_cgraph(const struct ggml_cgraph* cgraph) {
std::ofstream file("cgraph.txt");
void GgmlOvDecoder::dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename) {
std::ofstream file(filename);
if (!file.is_open()) {
std::cerr << "Failed to open file" << std::endl;
return;
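The new get_graph_input_shape above centralizes the per-input shape logic: static (NPU) graphs get fixed shapes, while CPU/GPU graphs get a bounded-dynamic token axis. A hedged sketch of what such a bounded ov::Dimension means in practice (the 512 bound is illustrative; the decoder derives m_max_token_len from cache_k_l0):

#include <cstdint>
#include <iostream>
#include <openvino/core/partial_shape.hpp>

int main() {
    const int64_t max_token_len = 512;  // illustrative bound
    // Token axis is bounded-dynamic, as for "inp_tokens" in get_graph_input_shape.
    ov::PartialShape shape{1, 1, ov::Dimension(1, max_token_len)};
    std::cout << shape << "\n";              // prints the shape with the bounded axis, e.g. [1,1,1..512]
    std::cout << shape.is_static() << "\n";  // 0: one compiled model serves any length in range
    return 0;
}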

View File

@@ -3,6 +3,7 @@
#include <cstdint>
#include <map>
#include <memory>
#include <openvino/core/partial_shape.hpp>
#include <vector>
#include "ggml.h"
@@ -89,28 +90,34 @@ public:
return m_model_output_names;
}
virtual bool is_static() const override {
return m_is_static;
}
virtual bool is_first_token() const override {
return m_is_first_token;
}
virtual int get_max_token_len() const override {
return m_max_token_len;
}
virtual int get_max_token_len() const override { return m_max_token_len; }
virtual int get_num_heads() const override { return m_num_heads; }
virtual int get_num_heads_kv() const override { return m_num_heads_kv; }
virtual int get_head_size() const override { return m_head_size; }
virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
virtual bool is_static() const override { return m_is_static; }
virtual bool is_first_token() const override { return m_is_first_token; }
ov::PartialShape get_graph_input_shape(const ggml_tensor* src) const;
private:
void set_input_output(ggml_tensor* node);
void add_extra_inputs();
static void dump_cgraph(const struct ggml_cgraph* cgraph);
static void dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename);
static std::vector<size_t> get_shape(const ggml_tensor* tensor);
static std::vector<size_t> get_stride(const ggml_tensor* tensor);
static ov::element::Type get_ov_type(const ggml_tensor* tensor);
// set max_token_len, num_heads, etc.
void set_llm_params();
static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor);
void set_max_token_len();
int m_max_token_len;
void add_weight_const_parallel(std::map<std::string, std::shared_ptr<ov::Node>>& model_weights);
struct ggml_cgraph* m_cgraph;
@@ -129,6 +136,11 @@ private:
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
int m_max_token_len;
int m_num_heads;
int m_num_heads_kv;
int m_head_size;
std::vector<std::string> m_kv_names;
bool m_is_static;
bool m_is_first_token;
};

View File

@@ -4,6 +4,7 @@
#include <map>
#include <openvino/core/node.hpp>
#include <openvino/frontend/decoder.hpp>
#include <string>
namespace ov {
namespace frontend {
@@ -57,6 +58,11 @@ public:
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_weights() const = 0;
virtual const std::vector<std::string>& get_model_output_names() const = 0;
virtual int get_num_heads() const = 0;
virtual int get_num_heads_kv() const = 0;
virtual int get_head_size() const = 0;
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
virtual bool is_static() const = 0;
virtual bool is_first_token() const = 0;
virtual int get_max_token_len() const = 0;

View File

@@ -12,6 +12,7 @@
#include <openvino/op/range.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scatter_nd_update.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/transpose.hpp>
@@ -57,13 +58,6 @@ OutputVector translate_cpy(const NodeContext& context) {
if (op_case == 1) {
// Write K to cache_k
int64_t head_size = src0_shape[2];
int64_t num_heads = src0_shape[1];
auto reshaped_src1_shape =
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, num_heads, head_size});
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(src1, reshaped_src1_shape, false);
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {0});
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
@@ -80,7 +74,8 @@ OutputVector translate_cpy(const NodeContext& context) {
}
indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
res = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, src0);
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(src1, indices, src0);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(src1), false);
} else {
// Write V to cache_v
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
@@ -140,7 +135,7 @@ OutputVector translate_cpy(const NodeContext& context) {
false);
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices_final, flattend_src0);
res = std::make_shared<ov::op::v0::Unsqueeze>(updated, zero);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(src1), false);
}
return rename_outputs_with_suffix({res}, context.get_name());
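Both branches above now end in the same pattern: ScatterNDUpdate into the unmodified cache input, then Reshape to ShapeOf(src1) so the result keeps exactly the cache tensor's shape — the precondition for pairing the Parameter and Result as state later. A hedged distillation of that pattern (the helper name is mine, not the commit's):

#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scatter_nd_update.hpp>
#include <openvino/op/shape_of.hpp>

// Scatter new rows into the cache, then restore the cache's own shape so the
// output can be paired with the cache Parameter by MakeStateful.
std::shared_ptr<ov::Node> scatter_into_cache(const ov::Output<ov::Node>& cache,
                                             const ov::Output<ov::Node>& indices,
                                             const ov::Output<ov::Node>& updates) {
    auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(cache, indices, updates);
    auto cache_shape = std::make_shared<ov::op::v0::ShapeOf>(cache);
    return std::make_shared<ov::op::v1::Reshape>(updated, cache_shape, false);
}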

View File

@@ -1,7 +1,12 @@
#include "translate_session.hpp"
#include <cstdlib>
#include <map>
#include <memory>
#include <openvino/op/parameter.hpp>
#include <openvino/op/result.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/make_stateful.hpp>
#include "input_model.hpp"
@@ -11,6 +16,41 @@ namespace ggml {
using namespace ov::op;
namespace {
ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
const std::shared_ptr<ov::Model>& model, const std::map<std::string, std::string>& kv_param_res_names) {
ov::pass::MakeStateful::ParamResPairs pairs;
const auto& params = model->get_parameters();
const auto& results = model->get_results();
for (const auto& param_res : kv_param_res_names) {
const auto& param_name = param_res.first;
const auto& res_name = param_res.second;
auto param_it = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<v0::Parameter>& node) {
return node->get_friendly_name() == param_name;
});
OPENVINO_ASSERT(param_it != params.end(), "The tensor name ", param_name,
" is not associated with any of "
"Parameters in the network.");
auto res_it = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<v0::Result>& node) {
return node->get_friendly_name() == res_name;
});
OPENVINO_ASSERT(res_it != results.end(), "The tensor name ", res_name,
" is not associated with any of "
"Results in the network.");
std::shared_ptr<ov::op::v0::Parameter> param = *param_it;
std::shared_ptr<ov::op::v0::Result> res = *res_it;
pairs.emplace_back(param, res);
}
return pairs;
}
} // namespace
TranslateSession::TranslateSession(const frontend::InputModel::Ptr& input_model,
const std::unordered_map<std::string, CreatorFunction>& translator_map)
: m_input_model(input_model),
@@ -88,25 +128,26 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
results.push_back(result);
}
ov::ParameterVector used_params;
for (const auto& param : params) {
if (!param->output(0).get_target_inputs().empty()) {
used_params.push_back(param);
}
}
if (getenv("GGML_OPENVINO_PROFILING")) {
if (auto diff = params.size() - used_params.size()) {
std::cout << diff << " parameters are not used in the model." << std::endl;
}
}
resulting_model = std::make_shared<Model>(results, used_params);
resulting_model = std::make_shared<Model>(results, params);
apply_transformations(resulting_model);
return resulting_model;
}
void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model) {
auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder();
ov::pass::Manager manager;
manager.set_per_pass_validation(true);
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(resulting_model);
return resulting_model;
if (!ggml_model_decoder->is_static()) {
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
}
manager.run_passes(model);
}
} // namespace ggml
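For reference, ov::pass::MakeStateful also has an overload taking a map of tensor names; the helper above resolves friendly names to Parameter/Result pointers manually because this frontend tracks friendly names. A hedged sketch of the name-based variant (names illustrative, and note it matches tensor names, not friendly names):

#include <map>
#include <memory>
#include <string>
#include <openvino/core/model.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/manager.hpp>

// Would replace get_kv_param_res_pairs if the KV tensors carried tensor names.
void make_kv_stateful_by_name(const std::shared_ptr<ov::Model>& model) {
    std::map<std::string, std::string> kv_names = {
        {"cache_k_l0", "cache_k_l0"},  // illustrative
        {"cache_v_l0", "cache_v_l0"},
    };
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::MakeStateful>(kv_names);
    manager.run_passes(model);
}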

View File

@@ -16,7 +16,7 @@ public:
std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);
private:
void print_model_topology();
void apply_transformations(const std::shared_ptr<Model>& model);
const frontend::InputModel::Ptr m_input_model;
const std::unordered_map<std::string, CreatorFunction>& m_translator_map;
std::shared_ptr<Model> m_ov_model;

View File

@@ -9,10 +9,13 @@
#include <memory>
#include <openvino/core/any.hpp>
#include <openvino/core/graph_util.hpp>
#include <openvino/core/partial_shape.hpp>
#include <openvino/core/type/float16.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/openvino.hpp>
#include <openvino/runtime/compiled_model.hpp>
#include <openvino/runtime/infer_request.hpp>
#include <openvino/runtime/intel_npu/properties.hpp>
#include <openvino/runtime/tensor.hpp>
#include <unordered_map>
@@ -28,11 +31,15 @@ std::shared_ptr<GgmlOvDecoder> get_ggml_decoder(struct ggml_cgraph* cgraph, bool
}
ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string& name) {
auto* input_data = ggml_decoder->get_input_ggml_tensor(name)->data;
ov::Tensor input_tensor;
ov::Shape input_shape = ggml_decoder->get_input_shape(name).to_shape();
std::vector<size_t> input_stride = ggml_decoder->get_input_stride(name);
input_tensor = ov::Tensor(ggml_decoder->get_input_type(name), input_shape, input_data);
const auto* ggml_tensor = ggml_decoder->get_input_ggml_tensor(name);
auto* input_data = ggml_tensor->data;
ov::Shape input_shape;
if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor).to_shape();
} else {
input_shape = ggml_decoder->get_input_shape(name).to_shape();
}
auto input_tensor = ov::Tensor(ggml_decoder->get_input_type(name), input_shape, input_data);
return input_tensor;
}
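Note that the ov::Tensor built above wraps the ggml buffer without copying, so the ggml tensor must stay alive for the duration of inference. A minimal illustration of that aliasing (self-contained, not the backend's code):

#include <vector>
#include <openvino/runtime/tensor.hpp>

int main() {
    std::vector<float> ggml_data(8, 0.f);  // stand-in for ggml_tensor->data
    ov::Tensor t(ov::element::f32, ov::Shape{1, 1, 8}, ggml_data.data());
    t.data<float>()[0] = 42.f;             // writes straight through to ggml_data
    return ggml_data[0] == 42.f ? 0 : 1;   // returns 0: same memory
}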
@@ -82,41 +89,37 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
core.set_property(ov::cache_dir(cache_dir));
}
// CPU and GPU will only use cache_prefill
using CachedItem = std::pair<std::shared_ptr<ov::Model>, ov::CompiledModel>;
static std::unordered_map<struct ggml_cgraph*, CachedItem> compiled_cache_prefill;
static std::unordered_map<struct ggml_cgraph*, CachedItem> compiled_cache_kvcache;
static std::unordered_map<struct ggml_cgraph*, std::shared_ptr<ov::InferRequest>> infer_request_cache;
static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_input_names_cache;
static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_output_names_cache;
// For NPU, store the kvcache model, since we cannot create two infer_request
static std::unordered_map<struct ggml_cgraph*, ov::CompiledModel> compiled_model_cache;
std::shared_ptr<GgmlOvDecoder> ggml_decoder;
std::shared_ptr<ov::Model> model;
ov::CompiledModel compiled_model;
ov::InferRequest infer_request;
int64_t decoder_end_time;
int64_t conversion_end_time;
int64_t compile_end_time;
bool is_first_token = is_prefill(cgraph);
auto it = compiled_cache_prefill.find(cgraph);
if (it != compiled_cache_prefill.end()) {
auto it = infer_request_cache.find(cgraph);
if (it != infer_request_cache.end()) {
ggml_decoder = get_ggml_decoder(cgraph, is_static, false);
decoder_end_time = ggml_time_us();
if (is_static) {
if (is_first_token) {
model = compiled_cache_prefill[cgraph].first;
compiled_model = compiled_cache_prefill[cgraph].second;
} else {
model = compiled_cache_kvcache[cgraph].first;
compiled_model = compiled_cache_kvcache[cgraph].second;
}
} else {
model = it->second.first;
compiled_model = it->second.second;
// For NPU, the first time the kvcache model is called, pop the compiled kvcache model from the cache
if (is_static && compiled_model_cache.find(cgraph) != compiled_model_cache.end()) {
infer_request_cache[cgraph] =
std::make_shared<ov::InferRequest>(compiled_model_cache[cgraph].create_infer_request());
compiled_model_cache.erase(cgraph);
}
infer_request = *infer_request_cache[cgraph];
conversion_end_time = ggml_time_us();
compile_end_time = conversion_end_time;
} else {
std::shared_ptr<ov::Model> model;
if (is_static) {
ggml_decoder = get_ggml_decoder(cgraph, is_static, true);
auto ggml_decoder_kvcache = get_ggml_decoder(cgraph, is_static, false);
@@ -129,12 +132,14 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
auto model_kvcache = ov::frontend::ggml::FrontEnd::convert(input_model_kvcache);
conversion_end_time = ggml_time_us();
compiled_model = core.compile_model(model, device, config);
auto compiled_model = core.compile_model(model, device, config);
auto compiled_model_kvcache = core.compile_model(model_kvcache, device, config);
compiled_model_cache[cgraph] = compiled_model_kvcache;
compile_end_time = ggml_time_us();
compiled_cache_prefill[cgraph] = std::make_pair(model, compiled_model);
compiled_cache_kvcache[cgraph] = std::make_pair(model_kvcache, compiled_model_kvcache);
infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
infer_request = *infer_request_cache[cgraph];
compiled_model_cache[cgraph] = compiled_model_kvcache;
if (getenv("GGML_OPENVINO_DUMP_IR")) {
char timestamped_filename[64];
@@ -152,9 +157,10 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
model = ov::frontend::ggml::FrontEnd::convert(input_model);
conversion_end_time = ggml_time_us();
compiled_model = core.compile_model(model, device, config);
auto compiled_model = core.compile_model(model, device, config);
compile_end_time = ggml_time_us();
compiled_cache_prefill[cgraph] = std::make_pair(model, compiled_model);
infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
infer_request = *infer_request_cache[cgraph];
if (getenv("GGML_OPENVINO_DUMP_IR")) {
char timestamped_filename[64];
@@ -163,12 +169,23 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
ov::serialize(model, timestamped_filename);
}
}
}
auto infer_request = compiled_model.create_infer_request();
auto ov_params = model->get_parameters();
for (size_t i = 0; i < ov_params.size(); i++) {
auto param_name = ov_params[i]->get_friendly_name();
std::vector<std::string> ov_input_names;
std::vector<std::string> ov_output_names;
for (const auto& ov_param : model->get_parameters()) {
ov_input_names.push_back(ov_param->get_friendly_name());
}
for (const auto& ov_output : model->get_results()) {
ov_output_names.push_back(ov_output->get_friendly_name());
}
ov_input_names_cache[cgraph] = ov_input_names;
ov_output_names_cache[cgraph] = ov_output_names;
}
auto ov_input_names = ov_input_names_cache[cgraph];
auto ov_output_names = ov_output_names_cache[cgraph];
for (size_t i = 0; i < ov_input_names.size(); i++) {
auto param_name = ov_input_names[i];
auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name);
infer_request.set_input_tensor(i, input_tensor);
@@ -181,14 +198,15 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
infer_request.infer();
auto infer_end_time = ggml_time_us();
auto output_names = ggml_decoder->get_model_output_names();
auto output_tensors = get_ggml_graph_output_dst(ggml_decoder);
for (size_t i = 0; i < output_names.size(); i++) {
auto output_tensor = infer_request.get_output_tensor(i);
std::memcpy(output_tensors[output_names[i]], output_tensor.data(), output_tensor.get_byte_size());
auto gguf_tensor_addrs = get_ggml_graph_output_dst(ggml_decoder);
for (size_t i = 0; i < ov_output_names.size(); i++) {
auto result_name = ov_output_names[i];
const auto output_tensor = infer_request.get_output_tensor(i);
std::memcpy(gguf_tensor_addrs[result_name], output_tensor.data(), output_tensor.get_byte_size());
if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) {
print_output_tensor_info(output_names[i], output_tensor, output_tensors);
print_output_tensor_info(result_name, output_tensor, gguf_tensor_addrs);
}
}
auto end_time = ggml_time_us();
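The control flow above boils down to a per-cgraph cache: compile once, create one InferRequest, and reuse it so CPU/GPU state persists between tokens. A hedged distillation of that scheme (signatures simplified; device/config handling and the NPU prefill/kvcache split omitted):

#include <memory>
#include <string>
#include <unordered_map>
#include <openvino/openvino.hpp>

struct ggml_cgraph;  // opaque here; the backend keys its caches by graph identity

// First call compiles the model and creates the request; later calls reuse it,
// so the state installed by MakeStateful survives across tokens.
ov::InferRequest& get_or_create_request(
    std::unordered_map<struct ggml_cgraph*, std::shared_ptr<ov::InferRequest>>& cache,
    struct ggml_cgraph* cgraph,
    ov::Core& core,
    const std::shared_ptr<ov::Model>& model,
    const std::string& device) {
    auto it = cache.find(cgraph);
    if (it == cache.end()) {
        auto compiled = core.compile_model(model, device);
        it = cache.emplace(cgraph, std::make_shared<ov::InferRequest>(compiled.create_infer_request())).first;
    }
    return *it->second;
}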