Просмотр исходного кода

opencl: fix rms_norm_mul (#17250)

* opencl: use subgrroup reduce for reduction in rms_norm_mul

* opencl: add comment about workgroup size
lhez 2 месяцев назад
Родитель
Сommit
52e5d421f1
2 измененных файлов с 26 добавлено и 11 удалено
  1. 1 1
      ggml/src/ggml-opencl/ggml-opencl.cpp
  2. 25 10
      ggml/src/ggml-opencl/kernels/rms_norm.cl

+ 1 - 1
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -5705,7 +5705,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
     CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong),      &nb2));
     CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong),      &nb2));
     CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),      &nb3));
     CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),      &nb3));
     CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),         &eps));
     CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),         &eps));
-    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,     NULL));
 
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 }

+ 25 - 10
ggml/src/ggml-opencl/kernels/rms_norm.cl

@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
     src1 = src1 + offset1;
     src1 = src1 + offset1;
     dst  = dst  + offsetd;
     dst  = dst  + offsetd;
 
 
+    // The size of sum is sizeof(float)*subgroup_size.
+    // Each subgroup writes its partial sum to this array.
+    // So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
+    // This is generally true -
+    // for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
+    if (get_sub_group_id() == 0) {
+        sum[get_sub_group_local_id()] = 0.0f;
+    }
+
     int i03 = get_group_id(2);
     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);
     int i01 = get_group_id(0);
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
         sumf += dot(x[i00], x[i00]);
         sumf += dot(x[i00], x[i00]);
     }
     }
     sumf = sub_group_reduce_add(sumf);
     sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
     if (get_sub_group_local_id() == 0) {
     if (get_sub_group_local_id() == 0) {
         sum[get_sub_group_id()] = sumf;
         sum[get_sub_group_id()] = sumf;
     }
     }
 
 
     barrier(CLK_LOCAL_MEM_FENCE);
     barrier(CLK_LOCAL_MEM_FENCE);
 
 
-    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
-       if (get_local_id(0) < i) {
-           sum[get_local_id(0)] += sum[get_local_id(0) + i];
-       }
-    }
-    if (get_local_id(0) == 0) {
-        sum[0] /= ne00;
-    }
+    //for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+    //   if (get_local_id(0) < i) {
+    //       sum[get_local_id(0)] += sum[get_local_id(0) + i];
+    //   }
+    //}
+    //if (get_local_id(0) == 0) {
+    //    sum[0] /= ne00;
+    //}
 
 
-    barrier(CLK_LOCAL_MEM_FENCE);
+    //barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = sum[get_sub_group_local_id()];
+    sumf = sub_group_reduce_add(sumf);
 
 
-    float mean  = sum[0];
+    float mean  = sumf / ne00;
     float scale = 1.0f/sqrt(mean + eps);
     float scale = 1.0f/sqrt(mean + eps);
 
 
     global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
     global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);