소스 검색

metal : fix data race in pipeline library (#17731)

Georgi Gerganov 1 개월 전
부모
커밋
3d94e967a1
2개의 변경된 파일22개의 추가작업 그리고 11개의 파일을 삭제
  1. 1 1
      ggml/src/ggml-metal/ggml-metal-device.cpp
  2. 21 10
      ggml/src/ggml-metal/ggml-metal-device.m

+ 1 - 1
ggml/src/ggml-metal/ggml-metal-device.cpp

@@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
 }
 
 ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
-    if  (ppls->data.find(name) == ppls->data.end()) {
+    if (ppls->data.find(name) == ppls->data.end()) {
         return nullptr;
     }
 

+ 21 - 10
ggml/src/ggml-metal/ggml-metal-device.m

@@ -146,6 +146,8 @@ struct ggml_metal_library {
     id<MTLDevice> device;
 
     ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
+
+    NSLock * lock;
 };
 
 ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
@@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
 
     ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
 
-    res->obj = library;
-    res->device = device;
+    res->obj       = library;
+    res->device    = device;
     res->pipelines = ggml_metal_pipelines_init();
+    res->lock      = [NSLock new];
 
     return res;
 }
@@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
     res->obj       = library;
     res->device    = device;
     res->pipelines = ggml_metal_pipelines_init();
+    res->lock      = [NSLock new];
 
     return res;
 }
@@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
 
     ggml_metal_pipelines_free(lib->pipelines);
 
+    [lib->lock release];
+
     free(lib);
 }
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
-    return ggml_metal_pipelines_get(lib->pipelines, name);
+    [lib->lock lock];
+
+    ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
+
+    [lib->lock unlock];
+
+    return res;
 }
 
 ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
-    // note: the pipelines are cached in the library per device, so they are shared across all metal contexts
-    ggml_critical_section_start();
+    [lib->lock lock];
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
     if (res) {
-        ggml_critical_section_end();
+        [lib->lock unlock];
 
         return res;
     }
@@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
             mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
         }
         if (!mtl_function) {
-            ggml_critical_section_end();
+            [lib->lock unlock];
 
             GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
             if (error) {
@@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
                 (int) res->obj.threadExecutionWidth);
 
         if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
-            ggml_critical_section_end();
+            [lib->lock unlock];
 
             GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
 
@@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
         ggml_metal_pipelines_add(lib->pipelines, name, res);
     }
 
-    ggml_critical_section_end();
+    [lib->lock unlock];
 
     return res;
 }