Просмотр исходного кода

enhance argsort for UT (#17573)

Co-authored-by: Neo Zhang <zhang.jianyu@outlook.com>
Neo Zhang Jianyu 1 месяц назад
Родитель
Сommit
98bd9ab1e4
1 измененных файлов с 7 добавлено и 1 удалено
  1. 7 1
      ggml/src/ggml-sycl/ggml-sycl.cpp

+ 7 - 1
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -1787,6 +1787,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const sycl::range<3> block_dims(1, 1, nth);
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
+    GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
 
     if (order == GGML_SORT_ORDER_ASC) {
         stream->submit([&](sycl::handler &cgh) {
@@ -4348,6 +4349,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
 }
 
 static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_sycl_device_context *sycl_ctx =
+        (ggml_backend_sycl_device_context *)dev->context;
+    int device = sycl_ctx->device;
     switch (op->op) {
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
@@ -4601,8 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
-        case GGML_OP_ARGSORT:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ARGSORT:
+            return op->src[0]->ne[0] * sizeof(int) <=
+                   ggml_sycl_info().devices[device].smpbo;
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
             return true;