Przeglądaj źródła

CUDA: properly handle nb00=nb02 case for cpy (#17081)

bssrdf 2 miesięcy temu
rodzic
commit
299f5d782c
2 zmienionych plików z 2 dodań i 3 usunięć
  1. 1 3
      ggml/src/ggml-cuda/cpy.cu
  2. 1 0
      tests/test-backend-ops.cpp

+ 1 - 3
ggml/src/ggml-cuda/cpy.cu

@@ -198,7 +198,7 @@ static void ggml_cpy_flt_cuda(
     if (transposed) {
         GGML_ASSERT(ne == ne00*ne01*ne02);  // ne[3] is 1 assumed
         int ne00n, ne01n, ne02n;
-        if (nb00 < nb02) {
+        if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
             ne00n = ne00;
             ne01n = ne01;
             ne02n = ne02;
@@ -206,8 +206,6 @@ static void ggml_cpy_flt_cuda(
             ne00n = ne00;
             ne01n = ne01*ne02;
             ne02n = 1;
-        } else {
-            GGML_ASSERT(false);
         }
 
         dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,

+ 1 - 0
tests/test-backend-ops.cpp

@@ -6648,6 +6648,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
 
     test_cases.emplace_back(new test_cont());
     test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));