Bläddra i källkod

sycl : unify unary kernels with a generic implementation and enable wide operator support (#17213)

* SYCL: add generic unary op implementation for multiple ops (ABS/SGN/…); unify non-contiguous access

* SYCL: update documentation and sycl.csv to reflect new unary op support

* update ops.md after syncing SYCL.csv changes

* Fix SYCL.csv merge conflict

* Update ops.md after fixing SYCL.csv conflicts

* Fix SYCL.csv tail after merge conflict and regenerate ops.md

* Fix line endings and final newline in SYCL.csv

* Remove TOPK_MOE entries from SYCL.csv as requested

* Update ops.md after removing TOPK_MOE from SYCL.csv

* Regenerated SYCL.csv and synced ops.md with upstream

* Update ops.md using create_ops_docs.py
shani-f 2 månader sedan
förälder
incheckning
72bd7321a7
4 ändrade filer med 513 tillägg och 424 borttagningar
  1. 25 26
      docs/ops.md
  2. 370 143
      docs/ops/SYCL.csv
  3. 112 250
      ggml/src/ggml-sycl/element_wise.cpp
  4. 6 5
      ggml/src/ggml-sycl/ggml-sycl.cpp

+ 25 - 26
docs/ops.md

@@ -14,7 +14,7 @@ Legend:
 
 | Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
 |-----------|------|------|------|------|------|------|------|------|------|
-|                              ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|                              ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | 🟡 | ❌ |
 |                              ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                              ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                             ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -23,7 +23,7 @@ Legend:
 |                           ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
 |                             CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
-|                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 |  | 🟡 | ❌ |
+|                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
 |                           CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
 |                             CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
 |                          CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
@@ -31,7 +31,7 @@ Legend:
 |                          CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-|                              COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | 🟡 | ❌ |
+|                              COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
 |                      COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                              CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
 |               CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -40,8 +40,8 @@ Legend:
 |                    DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                              DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                              DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
-|                              ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
-|                              EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|                              ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | ❌ | ❌ |
+|                              EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | 🟡 | ❌ |
 |                            EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                             FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                   FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
@@ -50,27 +50,27 @@ Legend:
 |                            GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 |                        GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 |                      GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
-|                             GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
-|                         GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
-|                       GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
+|                             GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 |  | 🟡 | ❌ |
+|                         GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 |  | 🟡 | ❌ |
+|                       GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 |  | 🟡 | ❌ |
 |                         GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
 |                    GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                       GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-|               GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |  | ❌ | ❌ |
-|                      HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
-|                        HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|               GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |  | ❌ | ❌ |
+|                      HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | 🟡 | ❌ |
+|                        HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | 🟡 | ❌ |
 |                           IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
 |                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
 |                          L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
-|                              LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |  | ❌ | ❌ |
+|                              LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
 |                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                              MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                          MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
 |                       MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
-|                              NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|                              NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | 🟡 | ❌ |
 |                             NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
-|                     NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |  | ❌ | ❌ |
+|                     NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |  | ❌ | ❌ |
 |                   OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
 |                     OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
 |                         OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
@@ -78,12 +78,12 @@ Legend:
 |                   PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
 |                          POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                            REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
-|                             RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
+|                             RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 |  | 🟡 | ❌ |
 |                           REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
 |                      REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                         RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
 |                    RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
-|                 RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ |  | ❌ | ❌ |
+|                 RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ |  | ❌ | ❌ |
 |                             ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                             ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                        ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
@@ -93,29 +93,28 @@ Legend:
 |                            SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                              SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
 |                         SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
-|                              SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
-|                          SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
-|                             SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
+|                              SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | ❌ | ❌ |
+|                          SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 |  | 🟡 | ❌ |
+|                             SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 |  | 🟡 | ❌ |
 |                        SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-|                              SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | 🟡 | ❌ |
-|                          SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |  | ❌ | ❌ |
+|                              SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|                          SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |  | ❌ | ❌ |
 |                         SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                         SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                    SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
 |                        SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-|                              SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | 🟡 | ❌ |
-|                             SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | 🟡 | ❌ |
+|                              SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|                             SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
 |                         SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                         SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
-|                             STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
+|                             STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ |  | ❌ | ❌ |
 |                              SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                              SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
 |                         SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
 |                           SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 |                       SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
-|                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
+|                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ |  | 🟡 | ❌ |
 |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-|                         TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                              TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                            TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
 |                          UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 370 - 143
docs/ops/SYCL.csv


+ 112 - 250
ggml/src/ggml-sycl/element_wise.cpp

@@ -170,73 +170,31 @@ static __dpct_inline__ T op_trunc(T x) {
     return sycl::trunc(x);
 }
 
-template<typename T>
-static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_sgn(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_abs(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_elu(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_gelu(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_silu(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_gelu_quick(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+template<typename T, typename F>
+static void unary_op_generic_kernel(
+        const T * x,
+        T * dst,
+        const int k,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
+        const size_t nb0,  const size_t nb1,  const size_t nb2,  const size_t nb3,
+        const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
+        const sycl::nd_item<1> & item_ct1,
+        F func) {
+
+        (void) ne3;
     SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_gelu_erf(x[i]);
-    }
-}
+        const int64_t i0 =  i % ne0;
+        const int64_t i1 = (i / ne0)        % ne1;
+        const int64_t i2 = (i / (ne0*ne1))  % ne2;
+        const int64_t i3 =  i / (ne0*ne1*ne2);
 
-template<typename T>
-static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_tanh(x[i]);
-    }
-}
+        const char * src_base = (const char *) x;
+        char       * dst_base = (char *) dst;
 
-template<typename T>
-static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_relu(x[i]);
-    }
-}
+        const T * srcp = (const T *)(src_base + i0*nb0  + i1*nb1  + i2*nb2  + i3*nb3 );
+        T *       dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
 
-template<typename T>
-static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_sigmoid(x[i]);
+        *dstp = func(*srcp);
     }
 }
 
@@ -261,27 +219,6 @@ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::n
     }
 }
 
-template<typename T>
-static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_hardsigmoid(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_hardswish(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_exp(x[i]);
-    }
-}
-
 template<typename T>
 static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
     SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -289,19 +226,6 @@ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::n
     }
 }
 
-template<typename T>
-static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_neg(x[i]);
-    }
-}
-
-template<typename T>
-static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
-    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
-        dst[i] = op_step(x[i]);
-    }
-}
 
 template<typename T>
 static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
@@ -620,6 +544,48 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
     }
 }
 
+template<typename F>
+static inline void ggml_sycl_op_unary(
+        ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
+
+    ggml_tensor * src0 = dst->src[0];
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const size_t  nb0  = src0->nb[0];
+    const size_t  nb1  = src0->nb[1];
+    const size_t  nb2  = src0->nb[2];
+    const size_t  nb3  = src0->nb[3];
+
+    const size_t  nbd0 = dst->nb[0];
+    const size_t  nbd1 = dst->nb[1];
+    const size_t  nbd2 = dst->nb[2];
+    const size_t  nbd3 = dst->nb[3];
+
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+
+            const int num_blocks = ceil_div(k_elements, 256);
+
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_generic_kernel(
+                        src, dst_ptr, k_elements,
+                        ne0, ne1, ne2, ne3,
+                        nb0, nb1, nb2, nb3,
+                        nbd0, nbd1, nbd2, nbd3,
+                        item_ct1,
+                        func
+                    );
+                });
+        });
+}
+
 
 static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -645,159 +611,75 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
 
 
 static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, 256);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
-                                  sycl::range<1>(256)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_sgn(x);
+    });
 }
 
+
 static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, 256);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
-                                  sycl::range<1>(256)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_abs(x);
+    });
 }
 
 static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, 256);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
-                                  sycl::range<1>(256)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_elu(x);
+    });
 }
-
 static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_silu(x);
+    });
 }
 
 static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_gelu(x);
+    });
 }
 
-static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_gelu_quick(x);
+    });
 }
 
-static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_gelu_erf(x);
+    });
 }
 
 static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_tanh(x);
+    });
 }
 
 static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_relu(x);
+    });
 }
 
 static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_hardsigmoid(x);
+    });
 }
 
 static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_hardswish(x);
+    });
 }
 
 static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_exp(x);
+    });
 }
 
 static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -814,42 +696,22 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
 }
 
 static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_neg(x);
+    });
 }
 
+
 static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_step(x);
+    });
 }
 
 static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
-                                  sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_sigmoid(x);
+    });
 }
 
 static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

+ 6 - 5
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -4360,21 +4360,22 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             }
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_GELU_ERF:
-                case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
-                case GGML_UNARY_OP_SGN:
-                case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_ELU:
+                    return true;
                 case GGML_UNARY_OP_FLOOR:
                 case GGML_UNARY_OP_CEIL:
                 case GGML_UNARY_OP_ROUND:

Vissa filer visades inte eftersom för många filer har ändrats