Kaynağa Gözat

SYCL: Add fp16 type support to unary op kernels (#12788)

* SYCL: Add fp16 support to some elementwise OP kernels

* remove comment

ggml-ci

* Use static_cast directly

* remove not needed cast from tanh

* Use static cast and remove unneeded castings

* Adjust device_support_op for unary OPs

* Use cast_data and typed_data struct to deduplicate casting code
Akarshan Biswas 9 ay önce
ebeveyn
işleme
fccf9cae83

Dosya farkı çok büyük olduğundan ihmal edildi
+ 702 - 215
ggml/src/ggml-sycl/element_wise.cpp


+ 24 - 0
ggml/src/ggml-sycl/element_wise.hpp

@@ -2,6 +2,13 @@
 #define GGML_SYCL_ELEMENTWISE_HPP
 
 #include "common.hpp"
+#include "ggml.h"
+#include <limits.h>
+
+template <typename T>
+T neg_infinity() {
+    return -std::numeric_limits<T>::infinity();
+}
 
 static __dpct_inline__ float op_repeat(const float a, const float b) {
     return b;
@@ -24,6 +31,19 @@ static __dpct_inline__ float op_div(const float a, const float b) {
     return a / b;
 }
 
+template<typename T>
+struct typed_data {
+    const T * src;
+    T * dst;
+};
+
+template<typename T>
+typed_data<T> cast_data(ggml_tensor * dst) {
+    return {
+        /* .src = */ static_cast<const T *>(dst->src[0]->data),
+        /* .dst = */ static_cast<T *>(dst->data)
+    };
+}
 
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
@@ -65,6 +85,10 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
+void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+// ---------
+
 void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

+ 11 - 49
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -1617,17 +1617,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
     dst[i] = scale * x[i];
 }
 
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
 
 template <typename Ti, typename To>
 static  void pool2d_nchw_kernel(
@@ -1768,18 +1757,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
         });
 }
 
-static void clamp_f32_sycl(const float *x, float *dst, const float min,
-                           const float max, const int k,
-                           queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            clamp_f32(x, dst, min, max, k, item_ct1);
-        });
-}
 
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, queue_ptr stream) {
@@ -2258,26 +2235,6 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst
     SYCL_CHECK(0);
 }
 
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
-    float *       dst_dd  = static_cast<float *>(dst->data);
-
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream());
-    /*
-    DPCT1010:88: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
-}
-
 static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     static bool peer_access_enabled = false;
 
@@ -3218,10 +3175,6 @@ static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     ggml_sycl_op_scale(ctx, dst);
 }
 
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_clamp(ctx, dst);
-}
-
 static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_diag_mask_inf(ctx, dst);
 }
@@ -3900,7 +3853,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
-                    return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
+#if defined (GGML_SYCL_F16)
+                    return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
+#else
+                    return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
                 default:
                     return false;
             }
@@ -4022,13 +3979,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return (op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
-            return (op->src[0]->type == GGML_TYPE_F32);
+#if defined (GGML_SYCL_F16)
+            return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
+#else
+            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor