ソースを参照

metal : support permuted matrix multiplicaions (#10033)

* metal : support permuted matrix multiplicaions

ggml-ci

* cont : use nb01 directly for row steps

ggml-ci

* cont : add comments [no ci]

* metal : minor refactor

* metal : minor
Georgi Gerganov 1 年間 前
コミット
668750357e
2 ファイル変更263 行追加155 行削除
  1. 42 33
      ggml/src/ggml-metal.m
  2. 221 122
      ggml/src/ggml-metal.metal

+ 42 - 33
ggml/src/ggml-metal.m

@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
-    //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    //if (src0) {
-    //    GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-    //            ggml_is_contiguous(src0), src0->name);
-    //}
-    //if (src1) {
-    //    GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-    //            ggml_is_contiguous(src1), src1->name);
-    //}
-    //if (dst) {
-    //    GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
-    //            dst->name);
-    //}
+#if 0
+    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    if (src0) {
+        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+    }
+    if (src1) {
+        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+    }
+    if (dst) {
+        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+    }
+#endif
 
     id<MTLDevice> device = ctx_dev->mtl_device;
 
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
                             [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
                             [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
                             [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
+                            [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11];
+                            [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
+                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15];
+                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16];
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
                             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
+                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19];
+                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20];
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                    src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
 
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ne03 == 1);
+                GGML_ASSERT(ne13 == 1);
+
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 // ne20 = n_used_experts

ファイルの差分が大きいため隠しています
+ 221 - 122
ggml/src/ggml-metal.metal


この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません