Thread safety per request only
This commit is contained in:
parent
fbc3128c17
commit
c397b1cfac
|
|
@ -19,7 +19,6 @@
|
|||
#include <iomanip>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <openvino/core/dimension.hpp>
|
||||
#include <openvino/core/except.hpp>
|
||||
#include <openvino/core/node.hpp>
|
||||
|
|
@ -577,9 +576,6 @@ std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const
|
|||
}
|
||||
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) {
|
||||
static std::mutex weights_mutex;
|
||||
std::lock_guard<std::mutex> lock(weights_mutex);
|
||||
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||
auto * nodes = cgraph->nodes;
|
||||
auto n_nodes = cgraph->n_nodes;
|
||||
|
|
|
|||
|
|
@ -106,17 +106,23 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
|
|||
int64_t infer_end_time;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(r_ctx->ov_compute_mutex);
|
||||
std::shared_ptr<std::mutex> mutex;
|
||||
|
||||
auto it = r_ctx->decoder_cache.find(key);
|
||||
|
||||
cache_hit = it != r_ctx->decoder_cache.end();
|
||||
ModelParams old_m_params;
|
||||
if (cache_hit) {
|
||||
ggml_decoder = it->second;
|
||||
mutex = it->second->mutex;
|
||||
std::lock_guard<std::mutex> lock(*(mutex));
|
||||
ggml_decoder = it->second->ptr;
|
||||
old_m_params = ggml_decoder->get_model_params();
|
||||
cache_hit = old_m_params.can_reuse_dynamically(m_params);
|
||||
} else {
|
||||
mutex = std::make_shared<std::mutex>();
|
||||
r_ctx->decoder_cache[key] = std::make_shared<decoder_runtime_ctx>(mutex);
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*(mutex));
|
||||
|
||||
if (cache_hit) {
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||
|
|
@ -200,7 +206,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
|
|||
compile_end_time = ggml_time_us();
|
||||
infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
|
||||
r_ctx->infer_request_cache[key] = infer_request;
|
||||
r_ctx->decoder_cache[key] = ggml_decoder;
|
||||
r_ctx->decoder_cache.at(key)->ptr = ggml_decoder;
|
||||
|
||||
std::vector<std::string> ov_input_names;
|
||||
std::vector<std::string> ov_output_names;
|
||||
|
|
@ -306,15 +312,23 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
|
|||
int64_t compile_end_time;
|
||||
int64_t infer_end_time;
|
||||
|
||||
std::shared_ptr<std::mutex> mutex;
|
||||
|
||||
auto it = r_ctx->decoder_cache.find(key);
|
||||
|
||||
cache_hit = it != r_ctx->decoder_cache.end();
|
||||
ModelParams old_m_params;
|
||||
if (cache_hit) {
|
||||
ggml_decoder = it->second;
|
||||
mutex = it->second->mutex;
|
||||
std::lock_guard<std::mutex> lock(*(mutex));
|
||||
ggml_decoder = it->second->ptr;
|
||||
old_m_params = ggml_decoder->get_model_params();
|
||||
cache_hit = old_m_params.can_reuse_statically(m_params);
|
||||
} else {
|
||||
mutex = std::make_shared<std::mutex>();
|
||||
r_ctx->decoder_cache[key] = std::make_shared<decoder_runtime_ctx>(mutex);
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*(mutex));
|
||||
|
||||
if (cache_hit) {
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||
|
|
@ -381,7 +395,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
|
|||
model = is_prefill ? model_prefill : model_decode;
|
||||
ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
|
||||
infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key];
|
||||
r_ctx->decoder_cache[key] = ggml_decoder;
|
||||
r_ctx->decoder_cache.at(key)->ptr = ggml_decoder;
|
||||
|
||||
std::vector<std::string> ov_input_names;
|
||||
std::vector<std::string> ov_output_names;
|
||||
|
|
|
|||
|
|
@ -40,11 +40,17 @@ struct graph_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
struct decoder_runtime_ctx {
|
||||
decoder_runtime_ctx(std::shared_ptr<std::mutex> mutex) :
|
||||
mutex(mutex) {}
|
||||
std::shared_ptr<std::mutex> mutex;
|
||||
std::shared_ptr<GgmlOvDecoder> ptr;
|
||||
};
|
||||
|
||||
struct ov_runtime_context {
|
||||
std::mutex ov_compute_mutex;
|
||||
std::string device;
|
||||
bool stateful;
|
||||
std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
|
||||
std::unordered_map<graph_key, std::shared_ptr<decoder_runtime_ctx>, graph_key_hash> decoder_cache;
|
||||
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
|
||||
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
|
||||
std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
|
||||
|
|
|
|||
Loading…
Reference in New Issue