From 59e7e7c47d444a6c7a25e90d3a00488966d6680f Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Wed, 3 Dec 2025 15:45:40 +0800
Subject: [PATCH] NPU fix llama-bench

---
 ggml/src/ggml-openvino/ggml-decoder.cpp     | 98 ++++++++++-----------
 ggml/src/ggml-openvino/ggml-decoder.h       | 93 ++++++++++++-------
 ggml/src/ggml-openvino/openvino/decoder.hpp |  2 +-
 ggml/src/ggml-openvino/utils.cpp            | 84 ++++++++++++------
 ggml/src/ggml-openvino/utils.h              | 22 +++++
 5 files changed, 191 insertions(+), 108 deletions(-)

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index c7035c1580..4c0258c4e3 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -36,6 +36,8 @@
 #include
 
 GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
+                             ModelParams & model_params,
+                             ComputeParams & compute_params,
                              std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
                              bool is_static,
                              bool is_prefill,
@@ -44,7 +46,9 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
     m_is_prefill(is_prefill),
     m_prefill_chunk_size(prefill_chunk_size),
     m_cgraph(cgraph),
-    m_model_weights(model_weights) {
+    m_model_weights(model_weights),
+    m_model_params(model_params),
+    m_compute_params(compute_params) {
     if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
 #ifdef _WIN32
         _putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", "");
@@ -54,7 +58,6 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
         print_tensor_address_map(cgraph);
     }
 
-    set_llm_params();
     validate_cgraph();
 
     for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
@@ -163,12 +166,6 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
     // Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph
     if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT ||
         node_name.find("output") != std::string::npos || debug_output_names.count(node_name)) {
-        if (node->op == GGML_OP_SET_ROWS) {
-            assert(node_name.find("cache_k") == 0 || node_name.find("cache_v") == 0);
-            if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), node_name); it == m_kv_names.end()) {
-                m_kv_names.push_back(node_name);
-            }
-        }
         if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), node_name);
             it == m_model_output_names.end()) {
             m_model_output_names.push_back(node_name);
@@ -277,9 +274,11 @@ int extract_layer_from_name(const std::string & name) {
     return layer;
 }
 
-void GgmlOvDecoder::set_llm_params() {
-    for (int i = 0; i < m_cgraph->n_nodes; i++) {
-        auto * node = m_cgraph->nodes[i];
+std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgraph * cgraph, bool is_static) {
+    ModelParams model_params;
+    ComputeParams compute_params;
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        auto * node = cgraph->nodes[i];
         std::string name = std::string(node->name);
         if (node->op == GGML_OP_FLASH_ATTN_EXT) {
             auto * cache_k_perm = node->src[1];
@@ -294,49 +293,50 @@ void GgmlOvDecoder::set_llm_params() {
             assert(mask_name.find("KQ_mask") == 0);
 
             if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
-                m_swa_layers.push_back(layer);
-                m_ctx_per_seq_swa = cache_k->ne[1];
+                model_params.swa_layers.push_back(layer);
+                model_params.ctx_per_seq_swa = cache_k->ne[1];
             } else {
-                m_ctx_per_seq = cache_k->ne[1];
-                m_n_seq = cache_k->ne[2];
+                model_params.ctx_per_seq = cache_k->ne[1];
+                model_params.n_seq = cache_k->ne[2];
             }
-            m_n_seq_active = mask->ne[3];
+            compute_params.n_seq_active = mask->ne[3];
 
             auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type);
             size_t offset;
             memcpy(&offset, cache_k_view->op_params, sizeof(size_t));
-            m_seq_active_start = offset / seq_size;
-            m_token_len_per_seq = node->ne[2];
+            compute_params.seq_active_start = offset / seq_size;
+            compute_params.token_len_per_seq = node->ne[2];
 
             if (mask_name.find("swa") != std::string::npos) {
-                m_attention_size_swa = mask->ne[0];
+                compute_params.attention_size_swa = mask->ne[0];
             } else {
-                m_attention_size = mask->ne[0];
+                compute_params.attention_size = mask->ne[0];
            }
-            if (m_is_static) {
-                m_attention_size = m_ctx_per_seq;
-                m_attention_size_swa = m_ctx_per_seq_swa;
-                m_token_len_per_seq = 1;
+            if (is_static) {
+                compute_params.attention_size = model_params.ctx_per_seq;
+                compute_params.attention_size_swa = model_params.ctx_per_seq_swa;
+                compute_params.token_len_per_seq = 1;
             }
         } else if (node->op == GGML_OP_ROPE) {
             if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
-                m_head_size = node->ne[0];
-                m_n_heads = node->ne[1];
-                m_rope_params = node->op_params;
+                model_params.head_size = node->ne[0];
+                model_params.n_heads = node->ne[1];
+                model_params.rope_params = node->op_params;
                 auto * inp_pos = node->src[1];
-                m_input_len = inp_pos->ne[0];
+                compute_params.input_len = inp_pos->ne[0];
             } else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
-                m_n_heads_kv = node->ne[1];
+                model_params.n_heads_kv = node->ne[1];
             }
         }
     }
 
-    m_ctx = m_ctx_per_seq * m_n_seq;
-    m_ctx_swa = m_ctx_per_seq_swa * m_n_seq;
+    model_params.ctx = model_params.ctx_per_seq * model_params.n_seq;
+    model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq;
+    return {model_params, compute_params};
 }
 
 void GgmlOvDecoder::validate_cgraph() const {
-    if (m_n_seq > 1 && m_is_static == true) {
+    if (m_model_params.n_seq > 1 && m_is_static == true) {
         throw std::runtime_error("n_seq > 1 is not supported on NPU. Try setting -np 1.");
     }
 }
@@ -354,7 +354,7 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
     } else if (name.find("KQ_mask") == 0) {
         if (m_is_static) {
-            input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_ctx};
+            input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
         } else {
             input_shape = ov::PartialShape{-1, 1, -1, -1};
         }
 
@@ -403,14 +403,14 @@ void GgmlOvDecoder::add_extra_inputs() {
         }
     };
 
-    create_1d_input("attention_size", m_attention_size);
-    if (m_attention_size_swa != -1) {
-        create_1d_input("attention_size_swa", m_attention_size_swa);
+    create_1d_input("attention_size", m_compute_params.attention_size);
+    if (m_compute_params.attention_size_swa != -1) {
+        create_1d_input("attention_size_swa", m_compute_params.attention_size_swa);
     }
-    create_1d_input("n_seq_active", m_n_seq_active);
-    create_1d_input("seq_active_start", m_seq_active_start);
-    create_1d_input("seq_active_end", m_seq_active_start + m_n_seq_active);
-    create_1d_input("token_len_per_seq", m_token_len_per_seq);
+    create_1d_input("n_seq_active", m_compute_params.n_seq_active);
+    create_1d_input("seq_active_start", m_compute_params.seq_active_start);
+    create_1d_input("seq_active_end", m_compute_params.seq_active_start + m_compute_params.n_seq_active);
+    create_1d_input("token_len_per_seq", m_compute_params.token_len_per_seq);
     // create_1d_input("token_len", m_token_len_per_seq * m_n_seq_active);
 }
@@ -445,15 +445,15 @@ const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name
     return nullptr;
 }
 
-std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
-    std::map<std::string, std::string> kv_param_res_names;
-    for (const auto & name : m_kv_names) {
-        if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
-            kv_param_res_names[name] = name;
-        }
-    }
-    return kv_param_res_names;
-}
+// std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
+//     std::map<std::string, std::string> kv_param_res_names;
+//     for (const auto & name : m_model_params.kv_names) {
+//         if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
+//             kv_param_res_names[name] = name;
+//         }
+//     }
+//     return kv_param_res_names;
+// }
 
 std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(
     ggml_cgraph * cgraph,
diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h
index 11f35f038e..f2efb65a23 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.h
+++ b/ggml/src/ggml-openvino/ggml-decoder.h
@@ -11,6 +11,42 @@
 #include
 #include
 
+struct ModelParams {
+    int ctx = -1;
+    int ctx_swa = -1;
+    int ctx_per_seq = -1;
+    int ctx_per_seq_swa = -1;
+    int n_seq = -1;
+    int n_heads = -1;
+    int n_heads_kv = -1;
+    int head_size = -1;
+    int32_t * rope_params = nullptr;
+    std::vector<int> swa_layers;
+
+    // std::vector<std::string> kv_names;
+
+    bool can_reuse_dynamically(const ModelParams & other) const {
+        return n_seq == other.n_seq && n_heads == other.n_heads && n_heads_kv == other.n_heads_kv &&
+               head_size == other.head_size && rope_params == other.rope_params && swa_layers == other.swa_layers;
+    }
+
+    bool can_reuse_statically(const ModelParams & other) const {
+        return can_reuse_dynamically(other) && ctx_per_seq == other.ctx_per_seq &&
+               ctx_per_seq_swa == other.ctx_per_seq_swa;
+    }
+};
+
+struct ComputeParams {
+    int n_seq_active = -1;
+    int seq_active_start = -1;
+    int attention_size = -1;
+    int attention_size_swa = -1;
+    int input_len = -1;
+    int token_len_per_seq = -1;
+    int past_kv_len = -1;
+    int output_len = -1;
+};
+
 class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
 public:
     struct NodeInfo {
@@ -25,6 +61,8 @@ public:
     };
 
     // Graph decoder
     GgmlOvDecoder(ggml_cgraph * cgraph,
+                  ModelParams & model_params,
+                  ComputeParams & compute_params,
                   std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
                   bool is_static,
                   bool is_prefill = false,
@@ -120,27 +158,28 @@ public:
     virtual const std::vector<std::string> & get_model_output_names() const override { return m_model_output_names; }
 
-    virtual int get_ctx_size() const { return m_ctx; }
+    virtual int get_ctx_size() const { return m_model_params.ctx; }
 
-    virtual int get_ctx_swa_size() const { return m_ctx_swa; }
+    virtual int get_ctx_swa_size() const { return m_model_params.ctx_swa; }
 
-    virtual int get_ctx_per_seq() const { return m_ctx_per_seq; }
+    virtual int get_ctx_per_seq() const { return m_model_params.ctx_per_seq; }
 
-    virtual int get_ctx_per_seq_swa() const { return m_ctx_per_seq_swa; }
+    virtual int get_ctx_per_seq_swa() const { return m_model_params.ctx_per_seq_swa; }
 
-    virtual int get_n_seq() const { return m_n_seq; }
+    virtual int get_n_seq() const { return m_model_params.n_seq; }
 
     virtual int is_swa_layer(int layer) const override {
-        return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
+        return std::find(m_model_params.swa_layers.begin(), m_model_params.swa_layers.end(), layer) !=
+               m_model_params.swa_layers.end();
     }
 
-    int get_past_kv_len() const { return m_past_kv_len; }
+    int get_past_kv_len() const { return m_compute_params.past_kv_len; }
 
-    int get_input_len() const { return m_input_len; }
+    int get_input_len() const { return m_compute_params.input_len; }
 
-    virtual int32_t * get_rope_params() const override { return m_rope_params; }
+    virtual int32_t * get_rope_params() const override { return m_model_params.rope_params; }
 
-    virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
+    // virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
 
     virtual bool is_static() const override { return m_is_static; }
@@ -161,6 +200,16 @@ public:
     void clear_model_weights() { m_model_weights.clear(); }
 
+    static std::pair<ModelParams, ComputeParams> compute_llm_params(ggml_cgraph * cgraph, bool is_static);
+
+    ModelParams get_model_params() const { return m_model_params; }
+
+    ComputeParams get_compute_params() const { return m_compute_params; }
+
+    void set_model_params(const ModelParams & model_params) { m_model_params = model_params; }
+
+    void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; }
+
     bool m_is_static = false;
     bool m_is_prefill = false;
     int m_prefill_chunk_size = 0;
@@ -174,7 +223,6 @@ private:
     int compute_op_case(const ggml_tensor * node);
     std::string compute_op_type(const ggml_tensor * node);
 
-    void set_llm_params();
     void validate_cgraph() const;
 
     ggml_cgraph * m_cgraph = nullptr;
@@ -191,27 +239,8 @@ private:
     std::vector<std::string> m_model_output_names;
     std::vector<NodeInfo> m_node_info_list;
 
-    // Fixed for a model
-    int m_ctx = -1;
-    int m_ctx_swa = -1;
-    int m_ctx_per_seq = -1;
-    int m_ctx_per_seq_swa = -1;
-    int m_n_seq = -1;
-    int m_n_heads = -1;
-    int m_n_heads_kv = -1;
-    int m_head_size = -1;
-    std::vector<int> m_swa_layers;
-    std::vector<std::string> m_kv_names;
-
-    // Changed per inference
-    int m_n_seq_active = -1;
-    int m_seq_active_start = -1;
-    int m_attention_size = -1;
-    int m_attention_size_swa = -1;
-    int m_input_len = -1;
-    int m_token_len_per_seq = -1;
-    int m_past_kv_len = -1;
-    int32_t * m_rope_params = nullptr;
+    ModelParams m_model_params;
+    ComputeParams m_compute_params;
 };
 
 void print_tensor_address_map(const ggml_cgraph * cgraph);
diff --git a/ggml/src/ggml-openvino/openvino/decoder.hpp b/ggml/src/ggml-openvino/openvino/decoder.hpp
index 1d5b7a850f..9c455a3724 100644
--- a/ggml/src/ggml-openvino/openvino/decoder.hpp
+++ b/ggml/src/ggml-openvino/openvino/decoder.hpp
@@ -75,7 +75,7 @@ public:
     virtual const std::vector<std::string>& get_model_output_names() const = 0;
 
     virtual int32_t* get_rope_params() const = 0;
-    virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
+    // virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
 
     virtual bool is_static() const = 0;
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index ae8916cc58..e90073a1f2 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -79,16 +79,21 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
     }
 
     static std::mutex cache_mutex;
-    static std::unordered_map<ggml_cgraph *, std::shared_ptr<ov::InferRequest>> infer_request_cache;
-    static std::unordered_map<ggml_cgraph *, std::shared_ptr<ov::InferRequest>> infer_request_cache_prefill;
-    static std::unordered_map<ggml_cgraph *, std::vector<std::string>> ov_input_names_cache;
-    static std::unordered_map<ggml_cgraph *, std::vector<std::string>> ov_output_names_cache;
+    static std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
+    static std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
+    static std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
+    static std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
+    static std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache;
 
     std::shared_ptr<GgmlOvDecoder> ggml_decoder;
     std::shared_ptr<ov::InferRequest> infer_request;
+    ModelParams m_params;
+    ComputeParams c_params;
+    std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static);
 
     const auto * inp_pos = get_inp_pos_tensor(cgraph);
     const auto is_prefill = get_is_prefill(inp_pos);
+    const auto key = compute_graph_key(cgraph);
 
     int64_t decoder_end_time;
     int64_t conversion_end_time;
@@ -98,25 +103,34 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
     {
         std::lock_guard lock(cache_mutex);
-        auto it = infer_request_cache.find(cgraph);
-        if (it != infer_request_cache.end()) {
+        auto it = decoder_cache.find(key);
+
+        auto cache_hit = it != decoder_cache.end();
+        if (cache_hit) {
+            ggml_decoder = it->second;
+            cache_hit = is_static ? ggml_decoder->get_model_params().can_reuse_statically(m_params) :
+                                    ggml_decoder->get_model_params().can_reuse_dynamically(m_params);
+        }
+
+        if (cache_hit) {
             std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
-            ggml_decoder =
-                std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, is_prefill, prefill_chunk_size);
+            ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static,
+                                                           is_prefill, prefill_chunk_size);
+            decoder_cache[key] = ggml_decoder;
             decoder_end_time = ggml_time_us();
 
-            infer_request = infer_request_cache[cgraph];
-            if (is_static && is_prefill) {
-                infer_request = infer_request_cache_prefill[cgraph];
-            }
+            infer_request = is_static && is_prefill ? infer_request_cache_prefill[key] : infer_request_cache[key];
             conversion_end_time = ggml_time_us();
             compile_end_time = conversion_end_time;
         } else {
+            infer_request_cache.erase(key);
+            infer_request_cache_prefill.erase(key);
+
             std::shared_ptr<ov::Model> model;
             auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, get_types_to_requant(device));
 
             if (!is_static) {
-                ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static);
+                ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static);
                 decoder_end_time = ggml_time_us();
 
                 auto input_model = std::make_shared(ggml_decoder);
@@ -133,13 +147,14 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
                 auto compiled_model = core.compile_model(model, device, get_ov_compile_config(device));
                 compile_end_time = ggml_time_us();
-                infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
-                infer_request = infer_request_cache[cgraph];
+                infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
+                infer_request_cache[key] = infer_request;
+                decoder_cache[key] = ggml_decoder;
             } else {
-                auto ggml_decoder_prefill =
-                    std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true, prefill_chunk_size);
-                auto ggml_decoder_decode =
-                    std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false, prefill_chunk_size);
+                auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
+                                                                            is_static, true, prefill_chunk_size);
+                auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
+                                                                           is_static, false, prefill_chunk_size);
                 decoder_end_time = ggml_time_us();
 
                 auto input_model_prefill = std::make_shared(ggml_decoder_prefill);
@@ -162,15 +177,17 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
                 auto compiled_model_prefill = core.compile_model(model_prefill, device, get_ov_compile_config(device));
                 auto compiled_model_decode = core.compile_model(model_decode, device, get_ov_compile_config(device));
-                infer_request_cache_prefill[cgraph] =
+
+                infer_request_cache_prefill[key] =
                     std::make_shared<ov::InferRequest>(compiled_model_prefill.create_infer_request());
-                infer_request_cache[cgraph] =
+                infer_request_cache[key] =
                     std::make_shared<ov::InferRequest>(compiled_model_decode.create_infer_request());
                 compile_end_time = ggml_time_us();
 
                 model = is_prefill ? model_prefill : model_decode;
                 ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
-                infer_request = is_prefill ? infer_request_cache_prefill[cgraph] : infer_request_cache[cgraph];
+                infer_request = is_prefill ? infer_request_cache_prefill[key] : infer_request_cache[key];
+                decoder_cache[key] = ggml_decoder;
             }
 
             std::vector<std::string> ov_input_names;
@@ -181,8 +198,8 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
             for (const auto & ov_output : model->get_results()) {
                 ov_output_names.push_back(ov_output->get_friendly_name());
             }
-            ov_input_names_cache[cgraph] = ov_input_names;
-            ov_output_names_cache[cgraph] = ov_output_names;
+            ov_input_names_cache[key] = ov_input_names;
+            ov_output_names_cache[key] = ov_output_names;
 
             // Set output tensors (for NPU) and kvcache i/o tensors once and for all
             // Note: does not seem to improve perf on CPU/GPU, but breaks llama-bench, so disabled it for CPU/GPU
@@ -205,8 +222,8 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
         }
     }
 
-    auto ov_input_names = ov_input_names_cache[cgraph];
-    auto ov_output_names = ov_output_names_cache[cgraph];
+    auto ov_input_names = ov_input_names_cache[key];
+    auto ov_output_names = ov_output_names_cache[key];
 
     if (!is_static) {
         for (size_t i = 0; i < ov_input_names.size(); i++) {
@@ -675,4 +692,19 @@ bool get_is_prefill(const ggml_tensor * inp_pos) {
     return inp_pos->ne[0] > 1;
 }
 
+graph_key compute_graph_key(ggml_cgraph * cgraph) {
+    graph_key key;
+    key.n_nodes = cgraph->n_nodes;
+
+    if (cgraph->n_nodes > 0) {
+        key.first_node_name = std::string(cgraph->nodes[0]->name);
+        key.last_node_name = std::string(cgraph->nodes[cgraph->n_nodes - 1]->name);
+    } else {
+        key.first_node_name = "";
+        key.last_node_name = "";
+    }
+
+    return key;
+}
+
 #pragma GCC diagnostic pop
diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h
index 31f86d0999..dca74f8afc 100644
--- a/ggml/src/ggml-openvino/utils.h
+++ b/ggml/src/ggml-openvino/utils.h
@@ -6,6 +6,26 @@
 #include
 #include
 
+struct graph_key {
+    size_t n_nodes;
+    std::string first_node_name;
+    std::string last_node_name;
+
+    bool operator==(const graph_key & other) const {
+        return n_nodes == other.n_nodes && first_node_name == other.first_node_name &&
+               last_node_name == other.last_node_name;
+    }
+};
+
+struct graph_key_hash {
+    size_t operator()(const graph_key & key) const {
+        size_t h = std::hash<size_t>{}(key.n_nodes);
+        h ^= std::hash<std::string>{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);
+        h ^= std::hash<std::string>{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);
+        return h;
+    }
+};
+
 enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
 size_t checksum(const void * data, size_t size);
@@ -46,6 +66,8 @@ const ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * cgraph);
 
 bool get_is_prefill(const ggml_tensor * inp_pos);
 
+graph_key compute_graph_key(struct ggml_cgraph * cgraph);
+
 ov::AnyMap get_ov_compile_config(const std::string & device);
 
 std::map get_types_to_requant(const std::string & device);
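
Reviewer note, not part of the patch: below is a minimal self-contained sketch of the content-derived cache key this change switches to. `graph_key`, `graph_key_hash`, and the key computation mirror the `utils.h`/`utils.cpp` hunks above; `fake_cgraph` and `main` are hypothetical scaffolding standing in for `ggml_cgraph`, which carries real node names. It demonstrates the property the fix relies on: when llama-bench rebuilds a structurally identical compute graph (a fresh `ggml_cgraph` at a new address), a cache keyed on the graph pointer misses, while a key derived from the node count and endpoint node names still hits.

```cpp
// Standalone sketch of the graph_key caching scheme (assumptions noted above).
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct graph_key {
    size_t n_nodes;
    std::string first_node_name;
    std::string last_node_name;

    bool operator==(const graph_key & other) const {
        return n_nodes == other.n_nodes && first_node_name == other.first_node_name &&
               last_node_name == other.last_node_name;
    }
};

struct graph_key_hash {
    size_t operator()(const graph_key & key) const {
        // boost::hash_combine-style mixing, as in the patch.
        size_t h = std::hash<size_t>{}(key.n_nodes);
        h ^= std::hash<std::string>{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);
        h ^= std::hash<std::string>{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);
        return h;
    }
};

struct fake_cgraph {                 // hypothetical stand-in for ggml_cgraph
    std::vector<std::string> nodes;  // node names only
};

graph_key compute_graph_key(const fake_cgraph & g) {
    graph_key key;
    key.n_nodes = g.nodes.size();
    key.first_node_name = g.nodes.empty() ? "" : g.nodes.front();
    key.last_node_name = g.nodes.empty() ? "" : g.nodes.back();
    return key;
}

int main() {
    // int stands in for the cached payload (decoder / infer request in the patch).
    std::unordered_map<graph_key, int, graph_key_hash> cache;

    fake_cgraph run1{{"inp_tokens", "attn_norm-0", "result_output"}};
    fake_cgraph run2 = run1;  // llama-bench-style rebuild: same structure, new object

    cache[compute_graph_key(run1)] = 42;
    // A pointer-keyed cache would miss here; the content-derived key hits.
    std::cout << "cache hit: " << cache.count(compute_graph_key(run2)) << "\n";  // prints 1
}
```

Because the key only samples the node count and the first/last node names, the patch deliberately does not trust a hit blindly: it recomputes `ModelParams` for the incoming graph and re-checks `can_reuse_statically` / `can_reuse_dynamically` against the cached decoder, erasing the stale infer requests and reconverting/recompiling when the check fails.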