소스 검색

ggml-cpu : use template for argsort (#17222)

Diego Devesa 2 달 전
부모
커밋
879dec341a
2개의 변경된 파일24개의 추가작업 그리고 8개의 파일을 삭제
  1. 22 8
      ggml/src/ggml-cpu/ops.cpp
  2. 2 0
      tests/test-backend-ops.cpp

+ 22 - 8
ggml/src/ggml-cpu/ops.cpp

@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
 
 // ggml_compute_forward_argsort
 
+template<enum ggml_sort_order order>
+struct argsort_cmp {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void ggml_compute_forward_argsort_f32(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
             dst_data[j] = j;
         }
 
-        std::function<bool(int32_t, int32_t)> cmp;
-
-        // note: this might be causing memory allocations? ideally should be avoided if it's the case
         switch (order) {
-            case GGML_SORT_ORDER_ASC:  cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
-            case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
-            default: GGML_ABORT("invalid sort order");
-        }
+            case GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+                break;
 
-        std::sort(dst_data, dst_data + ne0, cmp);
+            case GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
+            default:
+                GGML_ABORT("invalid sort order");
+        }
     }
 }
 

+ 2 - 0
tests/test-backend-ops.cpp

@@ -7631,6 +7631,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
     }
 
+    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
+
     return test_cases;
 }