Procházet zdrojové kódy

metal : faster argsort (#17315)

* metal : faster argsort

* cont : keep data in registers
Georgi Gerganov před 2 měsíci
rodič
revize
3347e6d904

+ 0 - 2
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -3726,8 +3726,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
         ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
         ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
 
-        ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
-
         ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
 
         std::swap(bid_dst, bid_tmp);

+ 87 - 39
ggml/src/ggml-metal/ggml-metal.metal

@@ -4739,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    int im  = tgpig[0] / args.ne01;
-    int i01 = tgpig[0] % args.ne01;
-    int i02 = tgpig[1];
-    int i03 = tgpig[2];
 
-    const int start = im * (2*args.len);
+    const int im  = tgpig[0] / args.ne01;
+    const int i01 = tgpig[0] % args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
+    const int start = im * (2 * args.len);
 
     const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
     const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
@@ -4768,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
         + args.nb02*i02
         + args.nb03*i03);
 
-    for (int k = tpitg.x; k < (int) total; k += ntg.x) {
-        // find partition (i,j) such that i+j = k
-        int low  = k > len1 ? k - len1 : 0;
-        int high = MIN(k, len0);
+    if (total == 0) {
+        return;
+    }
 
-        while (low < high) {
-            const int mid = (low + high) >> 1;
+    const int chunk = (total + ntg.x - 1) / ntg.x;
 
-            const int32_t idx0 = tmp0[mid];
-            const int32_t idx1 = tmp1[k - mid - 1];
+    const int k0 = tpitg.x * chunk;
+    const int k1 = min(k0 + chunk, total);
 
-            const float val0 = src0_row[idx0];
-            const float val1 = src0_row[idx1];
+    if (k0 >= total) {
+        return;
+    }
 
-            if (order == GGML_SORT_ORDER_ASC) {
-                if (val0 <= val1) {
-                    low = mid + 1;
-                } else {
-                    high = mid;
-                }
-            } else {
-                if (val0 >= val1) {
-                    low = mid + 1;
-                } else {
-                    high = mid;
-                }
-            }
+    int low  = k0 > len1 ? k0 - len1 : 0;
+    int high = MIN(k0, len0);
+
+    // binary-search partition (i, j) such that i + j = k
+    while (low < high) {
+        const int mid = (low + high) >> 1;
+
+        const int32_t idx0 = tmp0[mid];
+        const int32_t idx1 = tmp1[k0 - mid - 1];
+
+        const float val0 = src0_row[idx0];
+        const float val1 = src0_row[idx1];
+
+        bool take_left;
+        if (order == GGML_SORT_ORDER_ASC) {
+            take_left = (val0 <= val1);
+        } else {
+            take_left = (val0 >= val1);
         }
 
-        const int i = low;
-        const int j = k - i;
+        if (take_left) {
+            low = mid + 1;
+        } else {
+            high = mid;
+        }
+    }
+
+    int i = low;
+    int j = k0 - i;
+
+    // keep the merge fronts into registers
+    int32_t idx0 = 0;
+    float   val0 = 0.0f;
+    if (i < len0) {
+        idx0 = tmp0[i];
+        val0 = src0_row[idx0];
+    }
+
+    int32_t idx1 = 0;
+    float   val1 = 0.0f;
+    if (j < len1) {
+        idx1 = tmp1[j];
+        val1 = src0_row[idx1];
+    }
 
+    for (int k = k0; k < k1; ++k) {
         int32_t out_idx;
 
         if (i >= len0) {
-            out_idx = tmp1[j];
+            while (k < k1) {
+                dst[k++] = tmp1[j++];
+            }
+            break;
         } else if (j >= len1) {
-            out_idx = tmp0[i];
+            while (k < k1) {
+                dst[k++] = tmp0[i++];
+            }
+            break;
         } else {
-            const int32_t idx0 = tmp0[i];
-            const int32_t idx1 = tmp1[j];
+            bool take_left;
 
-            const float val0 = src0_row[idx0];
-            const float val1 = src0_row[idx1];
+            if (order == GGML_SORT_ORDER_ASC) {
+                take_left = (val0 <= val1);
+            } else {
+                take_left = (val0 >= val1);
+            }
 
-            out_idx = (order == GGML_SORT_ORDER_ASC)
-                ? (val0 <= val1 ? idx0 : idx1)
-                : (val0 >= val1 ? idx0 : idx1);
+            if (take_left) {
+                out_idx = idx0;
+                ++i;
+                if (i < len0) {
+                    idx0 = tmp0[i];
+                    val0 = src0_row[idx0];
+                }
+            } else {
+                out_idx = idx1;
+                ++j;
+                if (j < len1) {
+                    idx1 = tmp1[j];
+                    val1 = src0_row[idx1];
+                }
+            }
         }
 
         dst[k] = out_idx;