metal : fix data race in pipeline library (#17731)
This commit is contained in:
parent
7feb0a1005
commit
3d94e967a1
|
|
@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
|
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
|
||||||
if (ppls->data.find(name) == ppls->data.end()) {
|
if (ppls->data.find(name) == ppls->data.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -146,6 +146,8 @@ struct ggml_metal_library {
|
||||||
id<MTLDevice> device;
|
id<MTLDevice> device;
|
||||||
|
|
||||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||||
|
|
||||||
|
NSLock * lock;
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||||
|
|
@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||||
|
|
||||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||||
|
|
||||||
res->obj = library;
|
res->obj = library;
|
||||||
res->device = device;
|
res->device = device;
|
||||||
res->pipelines = ggml_metal_pipelines_init();
|
res->pipelines = ggml_metal_pipelines_init();
|
||||||
|
res->lock = [NSLock new];
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||||
res->obj = library;
|
res->obj = library;
|
||||||
res->device = device;
|
res->device = device;
|
||||||
res->pipelines = ggml_metal_pipelines_init();
|
res->pipelines = ggml_metal_pipelines_init();
|
||||||
|
res->lock = [NSLock new];
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||||
|
|
||||||
ggml_metal_pipelines_free(lib->pipelines);
|
ggml_metal_pipelines_free(lib->pipelines);
|
||||||
|
|
||||||
|
[lib->lock release];
|
||||||
|
|
||||||
free(lib);
|
free(lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||||
return ggml_metal_pipelines_get(lib->pipelines, name);
|
[lib->lock lock];
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||||
|
|
||||||
|
[lib->lock unlock];
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||||
// note: the pipelines are cached in the library per device, so they are shared across all metal contexts
|
[lib->lock lock];
|
||||||
ggml_critical_section_start();
|
|
||||||
|
|
||||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||||
if (res) {
|
if (res) {
|
||||||
ggml_critical_section_end();
|
[lib->lock unlock];
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||||
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||||
}
|
}
|
||||||
if (!mtl_function) {
|
if (!mtl_function) {
|
||||||
ggml_critical_section_end();
|
[lib->lock unlock];
|
||||||
|
|
||||||
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
||||||
if (error) {
|
if (error) {
|
||||||
|
|
@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||||
(int) res->obj.threadExecutionWidth);
|
(int) res->obj.threadExecutionWidth);
|
||||||
|
|
||||||
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
|
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
|
||||||
ggml_critical_section_end();
|
[lib->lock unlock];
|
||||||
|
|
||||||
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
|
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
|
||||||
|
|
||||||
|
|
@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||||
ggml_metal_pipelines_add(lib->pipelines, name, res);
|
ggml_metal_pipelines_add(lib->pipelines, name, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_critical_section_end();
|
[lib->lock unlock];
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue