Просмотр исходного кода

ggml : use std::sort in ggml_argsort CPU implementation (#17211)

* ggml : use std::sort in ggml_argsort CPU implementation

* cont : add missing header
Georgi Gerganov 2 месяцев назад
Родитель
Сommit
374fe09cdd
1 измененных файлов с 13 добавлено и 12 удалено
  1. 13 12
      ggml/src/ggml-cpu/ops.cpp

+ 13 - 12
ggml/src/ggml-cpu/ops.cpp

@@ -7,8 +7,9 @@
 #include "unary-ops.h"
 #include "vec.h"
 
-#include <float.h>
+#include <cfloat>
 #include <algorithm>
+#include <functional>
 
 // ggml_compute_forward_dup
 
@@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32(
     ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
 
     for (int64_t i = ith; i < nr; i += nth) {
-        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
         const float * src_data = (float *)((char *) src0->data + i*nb01);
 
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
         for (int64_t j = 0; j < ne0; j++) {
             dst_data[j] = j;
         }
 
-        // C doesn't have a functional sort, so we do a bubble sort instead
-        for (int64_t j = 0; j < ne0; j++) {
-            for (int64_t k = j + 1; k < ne0; k++) {
-                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
-                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
-                    int32_t tmp = dst_data[j];
-                    dst_data[j] = dst_data[k];
-                    dst_data[k] = tmp;
-                }
-            }
+        std::function<bool(int32_t, int32_t)> cmp;
+
+        // note: this might be causing memory allocations? ideally should be avoided if it's the case
+        switch (order) {
+            case GGML_SORT_ORDER_ASC:  cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
+            case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
+            default: GGML_ABORT("invalid sort order");
         }
+
+        std::sort(dst_data, dst_data + ne0, cmp);
     }
 }