Thread safety per request only
This commit is contained in:
parent
fbc3128c17
commit
c397b1cfac
|
|
@ -19,7 +19,6 @@
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex>
|
|
||||||
#include <openvino/core/dimension.hpp>
|
#include <openvino/core/dimension.hpp>
|
||||||
#include <openvino/core/except.hpp>
|
#include <openvino/core/except.hpp>
|
||||||
#include <openvino/core/node.hpp>
|
#include <openvino/core/node.hpp>
|
||||||
|
|
@ -577,9 +576,6 @@ std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) {
|
std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) {
|
||||||
static std::mutex weights_mutex;
|
|
||||||
std::lock_guard<std::mutex> lock(weights_mutex);
|
|
||||||
|
|
||||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||||
auto * nodes = cgraph->nodes;
|
auto * nodes = cgraph->nodes;
|
||||||
auto n_nodes = cgraph->n_nodes;
|
auto n_nodes = cgraph->n_nodes;
|
||||||
|
|
|
||||||
|
|
@ -106,17 +106,23 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
|
||||||
int64_t infer_end_time;
|
int64_t infer_end_time;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(r_ctx->ov_compute_mutex);
|
std::shared_ptr<std::mutex> mutex;
|
||||||
|
|
||||||
auto it = r_ctx->decoder_cache.find(key);
|
auto it = r_ctx->decoder_cache.find(key);
|
||||||
|
|
||||||
cache_hit = it != r_ctx->decoder_cache.end();
|
cache_hit = it != r_ctx->decoder_cache.end();
|
||||||
ModelParams old_m_params;
|
ModelParams old_m_params;
|
||||||
if (cache_hit) {
|
if (cache_hit) {
|
||||||
ggml_decoder = it->second;
|
mutex = it->second->mutex;
|
||||||
|
std::lock_guard<std::mutex> lock(*(mutex));
|
||||||
|
ggml_decoder = it->second->ptr;
|
||||||
old_m_params = ggml_decoder->get_model_params();
|
old_m_params = ggml_decoder->get_model_params();
|
||||||
cache_hit = old_m_params.can_reuse_dynamically(m_params);
|
cache_hit = old_m_params.can_reuse_dynamically(m_params);
|
||||||
|
} else {
|
||||||
|
mutex = std::make_shared<std::mutex>();
|
||||||
|
r_ctx->decoder_cache[key] = std::make_shared<decoder_runtime_ctx>(mutex);
|
||||||
}
|
}
|
||||||
|
std::lock_guard<std::mutex> lock(*(mutex));
|
||||||
|
|
||||||
if (cache_hit) {
|
if (cache_hit) {
|
||||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||||
|
|
@ -200,7 +206,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
|
||||||
compile_end_time = ggml_time_us();
|
compile_end_time = ggml_time_us();
|
||||||
infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
|
infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
|
||||||
r_ctx->infer_request_cache[key] = infer_request;
|
r_ctx->infer_request_cache[key] = infer_request;
|
||||||
r_ctx->decoder_cache[key] = ggml_decoder;
|
r_ctx->decoder_cache.at(key)->ptr = ggml_decoder;
|
||||||
|
|
||||||
std::vector<std::string> ov_input_names;
|
std::vector<std::string> ov_input_names;
|
||||||
std::vector<std::string> ov_output_names;
|
std::vector<std::string> ov_output_names;
|
||||||
|
|
@ -306,15 +312,23 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
|
||||||
int64_t compile_end_time;
|
int64_t compile_end_time;
|
||||||
int64_t infer_end_time;
|
int64_t infer_end_time;
|
||||||
|
|
||||||
|
std::shared_ptr<std::mutex> mutex;
|
||||||
|
|
||||||
auto it = r_ctx->decoder_cache.find(key);
|
auto it = r_ctx->decoder_cache.find(key);
|
||||||
|
|
||||||
cache_hit = it != r_ctx->decoder_cache.end();
|
cache_hit = it != r_ctx->decoder_cache.end();
|
||||||
ModelParams old_m_params;
|
ModelParams old_m_params;
|
||||||
if (cache_hit) {
|
if (cache_hit) {
|
||||||
ggml_decoder = it->second;
|
mutex = it->second->mutex;
|
||||||
|
std::lock_guard<std::mutex> lock(*(mutex));
|
||||||
|
ggml_decoder = it->second->ptr;
|
||||||
old_m_params = ggml_decoder->get_model_params();
|
old_m_params = ggml_decoder->get_model_params();
|
||||||
cache_hit = old_m_params.can_reuse_statically(m_params);
|
cache_hit = old_m_params.can_reuse_statically(m_params);
|
||||||
|
} else {
|
||||||
|
mutex = std::make_shared<std::mutex>();
|
||||||
|
r_ctx->decoder_cache[key] = std::make_shared<decoder_runtime_ctx>(mutex);
|
||||||
}
|
}
|
||||||
|
std::lock_guard<std::mutex> lock(*(mutex));
|
||||||
|
|
||||||
if (cache_hit) {
|
if (cache_hit) {
|
||||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||||
|
|
@ -381,7 +395,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
|
||||||
model = is_prefill ? model_prefill : model_decode;
|
model = is_prefill ? model_prefill : model_decode;
|
||||||
ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
|
ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
|
||||||
infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key];
|
infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key];
|
||||||
r_ctx->decoder_cache[key] = ggml_decoder;
|
r_ctx->decoder_cache.at(key)->ptr = ggml_decoder;
|
||||||
|
|
||||||
std::vector<std::string> ov_input_names;
|
std::vector<std::string> ov_input_names;
|
||||||
std::vector<std::string> ov_output_names;
|
std::vector<std::string> ov_output_names;
|
||||||
|
|
|
||||||
|
|
@ -40,11 +40,17 @@ struct graph_key_hash {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct decoder_runtime_ctx {
|
||||||
|
decoder_runtime_ctx(std::shared_ptr<std::mutex> mutex) :
|
||||||
|
mutex(mutex) {}
|
||||||
|
std::shared_ptr<std::mutex> mutex;
|
||||||
|
std::shared_ptr<GgmlOvDecoder> ptr;
|
||||||
|
};
|
||||||
|
|
||||||
struct ov_runtime_context {
|
struct ov_runtime_context {
|
||||||
std::mutex ov_compute_mutex;
|
|
||||||
std::string device;
|
std::string device;
|
||||||
bool stateful;
|
bool stateful;
|
||||||
std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
|
std::unordered_map<graph_key, std::shared_ptr<decoder_runtime_ctx>, graph_key_hash> decoder_cache;
|
||||||
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
|
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
|
||||||
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
|
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
|
||||||
std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
|
std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue