Przeglądaj źródła

opencl: add fastdiv and use it in set_rows, ported from cuda (#17090)

* opencl: add fastdiv for mm q8_0

* opencl: use uint4 for fastdiv vals

* opencl: use fastdiv for set_rows

* opencl: do not use fastdiv for q8_0 mm
lhez 2 miesięcy temu
rodzic
commit
ece0f5c177

+ 36 - 2
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -53,6 +53,37 @@
 
 
 bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
 bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
 
 
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+struct fastdiv_vals {
+    uint32_t mp;
+    uint32_t L;
+    uint32_t d;
+    uint32_t pad;
+};
+static_assert(sizeof(fastdiv_vals) == 16, "fastdiv_vals size incorrect");
+
+static fastdiv_vals init_fastdiv_values(uint64_t d_64) {
+    GGML_ASSERT(d_64 != 0);
+    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
+
+    uint32_t d = (uint32_t)d_64;
+
+    // compute L = ceil(log2(d));
+    uint32_t L = 0;
+    while (L < 32 && (uint32_t{ 1 } << L) < d) {
+        L++;
+    }
+
+    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+    // pack divisor as well to reduce error surface
+    return { mp, L, d, 0 };
+}
+
 enum GPU_FAMILY {
 enum GPU_FAMILY {
     ADRENO,
     ADRENO,
     INTEL,
     INTEL,
@@ -4464,6 +4495,9 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
             GGML_ABORT("not implemented");
             GGML_ABORT("not implemented");
     }
     }
 
 
+    fastdiv_vals ne11_ = init_fastdiv_values(ne11);
+    fastdiv_vals ne12_ = init_fastdiv_values(ne12);
+
     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
     CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
     CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
@@ -4474,8 +4508,8 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
     CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
     CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
     CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
     CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
     CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne11));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_));
     CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
     CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
     CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
     CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
     CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
     CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));

+ 35 - 16
ggml/src/ggml-opencl/kernels/set_rows.cl

@@ -1,5 +1,16 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
 
+// v = { mp, L, d }
+inline uint fastdiv(uint n, uint4 v) {
+    uint msbs;
+    msbs = mul_hi(n, v.s0);
+    return (msbs + n) >> v.s1;
+}
+inline uint fastmod(uint n, uint4 v) {
+    uint q = fastdiv(n, v);
+    return n - q * v.s2;
+}
+
 kernel void kernel_set_rows_f32_i64(
 kernel void kernel_set_rows_f32_i64(
         global char * src0,
         global char * src0,
         ulong         offset0,
         ulong         offset0,
@@ -11,8 +22,8 @@ kernel void kernel_set_rows_f32_i64(
         ulong         nb01,
         ulong         nb01,
         ulong         nb02,
         ulong         nb02,
         ulong         nb03,
         ulong         nb03,
-        int           ne11,
-        int           ne12,
+        uint4         ne11,
+        uint4         ne12,
         ulong         nb10,
         ulong         nb10,
         ulong         nb11,
         ulong         nb11,
         ulong         nb12,
         ulong         nb12,
@@ -33,8 +44,10 @@ kernel void kernel_set_rows_f32_i64(
         return;
         return;
     }
     }
 
 
-    int i12 = i03%ne12;
-    int i11 = i02%ne11;
+    //int i12 = i03%ne12;
+    //int i11 = i02%ne11;
+    int i12 = fastmod(i03, ne12);
+    int i11 = fastmod(i02, ne11);
 
 
     int i10 = i01;
     int i10 = i01;
     long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
     long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -58,8 +71,8 @@ kernel void kernel_set_rows_f16_i64(
         ulong         nb01,
         ulong         nb01,
         ulong         nb02,
         ulong         nb02,
         ulong         nb03,
         ulong         nb03,
-        int           ne11,
-        int           ne12,
+        uint4         ne11,
+        uint4         ne12,
         ulong         nb10,
         ulong         nb10,
         ulong         nb11,
         ulong         nb11,
         ulong         nb12,
         ulong         nb12,
@@ -80,8 +93,10 @@ kernel void kernel_set_rows_f16_i64(
         return;
         return;
     }
     }
 
 
-    int i12 = i03%ne12;
-    int i11 = i02%ne11;
+    //int i12 = i03%ne12;
+    //int i11 = i02%ne11;
+    int i12 = fastmod(i03, ne12);
+    int i11 = fastmod(i02, ne11);
 
 
     int i10 = i01;
     int i10 = i01;
     long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
     long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -105,8 +120,8 @@ kernel void kernel_set_rows_f32_i32(
         ulong         nb01,
         ulong         nb01,
         ulong         nb02,
         ulong         nb02,
         ulong         nb03,
         ulong         nb03,
-        int           ne11,
-        int           ne12,
+        uint4         ne11,
+        uint4         ne12,
         ulong         nb10,
         ulong         nb10,
         ulong         nb11,
         ulong         nb11,
         ulong         nb12,
         ulong         nb12,
@@ -127,8 +142,10 @@ kernel void kernel_set_rows_f32_i32(
         return;
         return;
     }
     }
 
 
-    int i12 = i03%ne12;
-    int i11 = i02%ne11;
+    //int i12 = i03%ne12;
+    //int i11 = i02%ne11;
+    int i12 = fastmod(i03, ne12);
+    int i11 = fastmod(i02, ne11);
 
 
     int i10 = i01;
     int i10 = i01;
     int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
     int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -152,8 +169,8 @@ kernel void kernel_set_rows_f16_i32(
         ulong         nb01,
         ulong         nb01,
         ulong         nb02,
         ulong         nb02,
         ulong         nb03,
         ulong         nb03,
-        int           ne11,
-        int           ne12,
+        uint4         ne11,
+        uint4         ne12,
         ulong         nb10,
         ulong         nb10,
         ulong         nb11,
         ulong         nb11,
         ulong         nb12,
         ulong         nb12,
@@ -174,8 +191,10 @@ kernel void kernel_set_rows_f16_i32(
         return;
         return;
     }
     }
 
 
-    int i12 = i03%ne12;
-    int i11 = i02%ne11;
+    //int i12 = i03%ne12;
+    //int i11 = i02%ne11;
+    int i12 = fastmod(i03, ne12);
+    int i11 = fastmod(i02, ne11);
 
 
     int i10 = i01;
     int i10 = i01;
     int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
     int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];