Prechádzať zdrojové kódy

ggml : check cuda and metal argsort limits and add test (#16323)

* check cuda argsort limits and add test

* add metal check
Sigbjørn Skjæret 3 mesiacov pred
rodič
commit
adc76347d7

+ 3 - 1
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3639,9 +3639,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM:
-        case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
         case GGML_OP_ACC:
             return true;
             return true;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_GROUP_NORM:

+ 3 - 1
ggml/src/ggml-metal/ggml-metal-device.m

@@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_ARANGE:
         case GGML_OP_ARANGE:
             return true;
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_EXT:

+ 1 - 0
tests/test-backend-ops.cpp

@@ -6567,6 +6567,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
     }
     }
 
 
     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {