|
|
@@ -7,8 +7,9 @@
|
|
|
#include "unary-ops.h"
|
|
|
#include "vec.h"
|
|
|
|
|
|
-#include <float.h>
|
|
|
+#include <cfloat>
|
|
|
#include <algorithm>
|
|
|
+#include <functional>
|
|
|
|
|
|
// ggml_compute_forward_dup
|
|
|
|
|
|
@@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32(
|
|
|
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
|
|
|
|
|
for (int64_t i = ith; i < nr; i += nth) {
|
|
|
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
|
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
|
|
|
|
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
|
+
|
|
|
for (int64_t j = 0; j < ne0; j++) {
|
|
|
dst_data[j] = j;
|
|
|
}
|
|
|
|
|
|
- // C doesn't have a functional sort, so we do a bubble sort instead
|
|
|
- for (int64_t j = 0; j < ne0; j++) {
|
|
|
- for (int64_t k = j + 1; k < ne0; k++) {
|
|
|
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
|
|
|
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
|
|
|
- int32_t tmp = dst_data[j];
|
|
|
- dst_data[j] = dst_data[k];
|
|
|
- dst_data[k] = tmp;
|
|
|
- }
|
|
|
- }
|
|
|
+ std::function<bool(int32_t, int32_t)> cmp;
|
|
|
+
|
|
|
+ // note: the std::function comparator might cause memory allocations - ideally this should be avoided if that is the case
|
|
|
+ switch (order) {
|
|
|
+ case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
|
|
|
+ case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
|
|
|
+ default: GGML_ABORT("invalid sort order");
|
|
|
}
|
|
|
+
|
|
|
+ std::sort(dst_data, dst_data + ne0, cmp);
|
|
|
}
|
|
|
}
|
|
|
|