Przeglądaj źródła

metal : fix synchronization in new matrix multiplication kernel (#2686)

Shouzheng Liu 2 lat temu
rodzic
commit
dadbed99e6
1 zmienionych plików z 2 dodań i 1 usunięć
  1. 2 1
      ggml-metal.metal

+ 2 - 1
ggml-metal.metal

@@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
         for (int i = 0; i < 8; i++) {
+            threadgroup_barrier(mem_flags::mem_device);
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
         }
 
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_device);
         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         if (sgitg==0) {
         if (sgitg==0) {
             for (int i = 0; i < n_rows; i++) {
             for (int i = 0; i < n_rows; i++) {