|
@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
|
|
|
|
|
|
|
|
// ggml_compute_forward_argsort
|
|
// ggml_compute_forward_argsort
|
|
|
|
|
|
|
|
|
|
+template<enum ggml_sort_order order>
|
|
|
|
|
+struct argsort_cmp {
|
|
|
|
|
+ const float * data;
|
|
|
|
|
+ bool operator()(int32_t a, int32_t b) const {
|
|
|
|
|
+ if constexpr (order == GGML_SORT_ORDER_ASC) {
|
|
|
|
|
+ return data[a] < data[b];
|
|
|
|
|
+ } else {
|
|
|
|
|
+ return data[a] > data[b];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+};
|
|
|
|
|
+
|
|
|
static void ggml_compute_forward_argsort_f32(
|
|
static void ggml_compute_forward_argsort_f32(
|
|
|
const ggml_compute_params * params,
|
|
const ggml_compute_params * params,
|
|
|
ggml_tensor * dst) {
|
|
ggml_tensor * dst) {
|
|
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
|
|
|
dst_data[j] = j;
|
|
dst_data[j] = j;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- std::function<bool(int32_t, int32_t)> cmp;
|
|
|
|
|
-
|
|
|
|
|
- // note: this might be causing memory allocations? ideally should be avoided if it's the case
|
|
|
|
|
switch (order) {
|
|
switch (order) {
|
|
|
- case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
|
|
|
|
|
- case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
|
|
|
|
|
- default: GGML_ABORT("invalid sort order");
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ case GGML_SORT_ORDER_ASC:
|
|
|
|
|
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
|
|
|
|
|
+ break;
|
|
|
|
|
|
|
|
- std::sort(dst_data, dst_data + ne0, cmp);
|
|
|
|
|
|
|
+ case GGML_SORT_ORDER_DESC:
|
|
|
|
|
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
|
|
|
|
|
+ break;
|
|
|
|
|
+
|
|
|
|
|
+ default:
|
|
|
|
|
+ GGML_ABORT("invalid sort order");
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|