diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 329500a03e..c647baef87 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg } ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { - if (ppls->data.find(name) == ppls->data.end()) { + if (ppls->data.find(name) == ppls->data.end()) { return nullptr; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 62bc4ba45f..4d2bfcf91c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -146,6 +146,8 @@ struct ggml_metal_library { id device; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines + + NSLock * lock; }; ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { @@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); - res->obj = library; - res->device = device; + res->obj = library; + res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev res->obj = library; res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { ggml_metal_pipelines_free(lib->pipelines); + [lib->lock release]; + free(lib); } ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { - return ggml_metal_pipelines_get(lib->pipelines, name); + [lib->lock lock]; + + ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); + + [lib->lock unlock]; + + return res; } ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { - // note: the pipelines are cached in the library per device, so they are shared across all metal contexts - ggml_critical_section_start(); + [lib->lock lock]; - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); if (res) { - ggml_critical_section_end(); + [lib->lock unlock]; return res; } @@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; } if (!mtl_function) { - ggml_critical_section_end(); + [lib->lock unlock]; GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); if (error) { @@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l (int) res->obj.threadExecutionWidth); if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { - ggml_critical_section_end(); + [lib->lock unlock]; GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); @@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l ggml_metal_pipelines_add(lib->pipelines, name, res); } - ggml_critical_section_end(); + [lib->lock unlock]; return res; }