Ver Fonte

metal : round up to 16 to fix MTLDebugComputeCommandEncoder assertion (#3938)

Peter Sugihara há 2 anos atrás
pai
commit
d9b33fe95b
1 ficheiros alterados com 2 adições e 2 exclusões
  1. 2 2
      ggml-metal.m

+ 2 - 2
ggml-metal.m

@@ -1017,7 +1017,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1348,7 +1348,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);