Просмотр исходного кода

cuda : fix argsort with 64k+ rows (#16849)

Sigbjørn Skjæret 2 месяцев назад
Родитель
Сommit
229bf68628
2 измененных файлов с 4 добавлено и 3 удалено
  1. 2 2
      ggml/src/ggml-cuda/argsort.cu
  2. 2 1
      tests/test-backend-ops.cpp

+ 2 - 2
ggml/src/ggml-cuda/argsort.cu

@@ -87,7 +87,7 @@ template<ggml_sort_order order>
 static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
-    int row = blockIdx.y;
+    int row = blockIdx.x;
 
     if (col >= ncols_pad) {
         return;
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float *   x,
     const int ncols_pad = next_power_of_2(ncols);
 
     const dim3 block_dims(ncols_pad, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer

+ 2 - 1
tests/test-backend-ops.cpp

@@ -7111,7 +7111,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
-        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
     }
 
     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {