Make thread safety per-request: replace the global static weights mutex with a per-graph-key mutex stored alongside each cached decoder

This commit is contained in:
Mustafa Cavus 2026-03-17 15:55:29 -07:00 committed by Zijun Yu
parent fbc3128c17
commit c397b1cfac
3 changed files with 27 additions and 11 deletions

View File

@ -19,7 +19,6 @@
#include <iomanip>
#include <map>
#include <memory>
#include <mutex>
#include <openvino/core/dimension.hpp>
#include <openvino/core/except.hpp>
#include <openvino/core/node.hpp>
@ -577,9 +576,6 @@ std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const
}
std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) {
static std::mutex weights_mutex;
std::lock_guard<std::mutex> lock(weights_mutex);
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
auto * nodes = cgraph->nodes;
auto n_nodes = cgraph->n_nodes;

View File

@ -106,17 +106,23 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
int64_t infer_end_time;
{
std::lock_guard<std::mutex> lock(r_ctx->ov_compute_mutex);
std::shared_ptr<std::mutex> mutex;
auto it = r_ctx->decoder_cache.find(key);
cache_hit = it != r_ctx->decoder_cache.end();
ModelParams old_m_params;
if (cache_hit) {
ggml_decoder = it->second;
mutex = it->second->mutex;
std::lock_guard<std::mutex> lock(*(mutex));
ggml_decoder = it->second->ptr;
old_m_params = ggml_decoder->get_model_params();
cache_hit = old_m_params.can_reuse_dynamically(m_params);
} else {
mutex = std::make_shared<std::mutex>();
r_ctx->decoder_cache[key] = std::make_shared<decoder_runtime_ctx>(mutex);
}
std::lock_guard<std::mutex> lock(*(mutex));
if (cache_hit) {
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
@ -200,7 +206,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
compile_end_time = ggml_time_us();
infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
r_ctx->infer_request_cache[key] = infer_request;
r_ctx->decoder_cache[key] = ggml_decoder;
r_ctx->decoder_cache.at(key)->ptr = ggml_decoder;
std::vector<std::string> ov_input_names;
std::vector<std::string> ov_output_names;
@ -306,15 +312,23 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
int64_t compile_end_time;
int64_t infer_end_time;
std::shared_ptr<std::mutex> mutex;
auto it = r_ctx->decoder_cache.find(key);
cache_hit = it != r_ctx->decoder_cache.end();
ModelParams old_m_params;
if (cache_hit) {
ggml_decoder = it->second;
mutex = it->second->mutex;
std::lock_guard<std::mutex> lock(*(mutex));
ggml_decoder = it->second->ptr;
old_m_params = ggml_decoder->get_model_params();
cache_hit = old_m_params.can_reuse_statically(m_params);
} else {
mutex = std::make_shared<std::mutex>();
r_ctx->decoder_cache[key] = std::make_shared<decoder_runtime_ctx>(mutex);
}
std::lock_guard<std::mutex> lock(*(mutex));
if (cache_hit) {
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
@ -381,7 +395,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
model = is_prefill ? model_prefill : model_decode;
ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key];
r_ctx->decoder_cache[key] = ggml_decoder;
r_ctx->decoder_cache.at(key)->ptr = ggml_decoder;
std::vector<std::string> ov_input_names;
std::vector<std::string> ov_output_names;

View File

@ -40,11 +40,17 @@ struct graph_key_hash {
}
};
// Per-graph-key runtime state held in ov_runtime_context::decoder_cache.
// Pairs the cached decoder with the mutex that serializes updates to this
// cache entry, so callers lock per request instead of globally.
struct decoder_runtime_ctx {
    // Takes shared ownership of the entry's mutex. `explicit` prevents an
    // accidental implicit conversion from a bare shared_ptr<std::mutex>;
    // the parameter is moved into the member to avoid a needless atomic
    // refcount increment/decrement on the shared_ptr copy.
    explicit decoder_runtime_ctx(std::shared_ptr<std::mutex> mutex) :
        mutex(std::move(mutex)) {}
    std::shared_ptr<std::mutex> mutex;   // guards updates to `ptr` for this key
    std::shared_ptr<GgmlOvDecoder> ptr;  // cached decoder; null until first populated
};
struct ov_runtime_context {
std::mutex ov_compute_mutex;
std::string device;
bool stateful;
std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
std::unordered_map<graph_key, std::shared_ptr<decoder_runtime_ctx>, graph_key_hash> decoder_cache;
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;