Browse Source

opencl: add TRI op support (#18979)

shaofeiqi 1 week ago
parent
commit
5516b9c16a

+ 1 - 0
ggml/src/ggml-opencl/CMakeLists.txt

@@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS
     add
     add_id
     argsort
+    tri
     fill
     clamp
     cpy

+ 65 - 0
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -489,6 +489,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
+    cl_kernel kernel_tri;
     cl_kernel kernel_fill;
     cl_kernel kernel_clamp;
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
@@ -793,6 +794,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // tri
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "tri.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("tri.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err));
+        GGML_LOG_CONT(".");
+
+        CL_CHECK(clReleaseProgram(prog));
+    }
+
     // fill
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3205,6 +3224,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 default:
                     return false;
             }
+        case GGML_OP_TRI:
+            return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
         case GGML_OP_FILL:
             return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
         case GGML_OP_CLAMP:
@@ -5965,6 +5986,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
 }
 
+static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int tri_type = ggml_get_op_params_i32(dst, 0);
+    const int64_t n = ggml_nelements(dst);
+    const int     ne0  = dst->ne[0];
+    const int     ne1  = dst->ne[1];
+
+    cl_kernel kernel = backend_ctx->kernel_tri;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &n));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne1));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &tri_type));
+
+    size_t local_work_size[1] = { 256 };
+    size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
@@ -10012,6 +10071,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_glu;
             break;
+        case GGML_OP_TRI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_tri;
+            break;
         case GGML_OP_FILL:
             if (!any_on_device) {
                 return false;

+ 32 - 0
ggml/src/ggml-opencl/kernels/tri.cl

@@ -0,0 +1,32 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// tri
+//------------------------------------------------------------------------------
+__kernel void kernel_tri_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int n,
+        int ne0,
+        int ne1,
+        int tri_type
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int idx = get_global_id(0);
+    if (idx >= n) return;
+
+    int i0 = idx % ne0;
+    int i1 = (idx / ne0) % ne1;
+
+    int keep = 0;
+    if (tri_type == 0) keep = (i0 >= i1);
+    else if (tri_type == 1) keep = (i0 >  i1);
+    else if (tri_type == 2) keep = (i0 <= i1);
+    else                    keep = (i0 <  i1);
+
+    dst[idx] = keep ? src0[idx] : 0.0f;
+}