Przeglądaj źródła

metal : move mm_id indices to shared mem (#5982)

Georgi Gerganov 1 rok temu
rodzic
commit
bb6d00bbf9
2 zmienionych plików z 6 dodań i 6 usunięć
  1. 3 3
      ggml-metal.m
  2. 3 3
      ggml-metal.metal

+ 3 - 3
ggml-metal.m

@@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         // TODO: make this more general
                         GGML_ASSERT(n_as <= 8);
 
-                        // max size of the src1ids array in the kernel stack
-                        GGML_ASSERT(ne11 <= 512);
+                        // max size of the src1ids array in the kernel shared buffer
+                        GGML_ASSERT(ne11 <= 4096);
 
                         const int64_t  ne20 = src2 ? src2->ne[0] : 0;
                         const int64_t  ne21 = src2 ? src2->ne[1] : 0;
@@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                             }
 
-                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {

+ 3 - 3
ggml-metal.metal

@@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
 void kernel_mul_mm_id_impl(
         device const  uchar * src0,
         device const  uchar * src1,
-        thread        short * src1ids,
+        threadgroup   short * src1ids,
         device        float * dst,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
@@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
     tgpig.z = tgpig.z%(ne12*ne13);
 
     // row indices of src1 for expert id
-    int64_t _ne1 = 0;
-    short src1ids[512];
+    threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
 
+    int64_t _ne1 = 0;
     for (int64_t i1 = 0; i1 < ne1; i1++) {
         if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
             src1ids[_ne1++] = i1;