|
|
@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
|
|
|
src1 = src1 + offset1;
|
|
|
dst = dst + offsetd;
|
|
|
|
|
|
+ // The size of sum is sizeof(float)*subgroup_size.
|
|
|
+ // Each subgroup writes its partial sum to this array.
|
|
|
+ // So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
|
|
|
+ // This is generally true -
|
|
|
+ // for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
|
|
|
+ if (get_sub_group_id() == 0) {
|
|
|
+ sum[get_sub_group_local_id()] = 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
int i03 = get_group_id(2);
|
|
|
int i02 = get_group_id(1);
|
|
|
int i01 = get_group_id(0);
|
|
|
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
|
|
|
sumf += dot(x[i00], x[i00]);
|
|
|
}
|
|
|
sumf = sub_group_reduce_add(sumf);
|
|
|
+
|
|
|
+ barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
+
|
|
|
if (get_sub_group_local_id() == 0) {
|
|
|
sum[get_sub_group_id()] = sumf;
|
|
|
}
|
|
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
|
|
- for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
|
|
- if (get_local_id(0) < i) {
|
|
|
- sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
|
|
- }
|
|
|
- }
|
|
|
- if (get_local_id(0) == 0) {
|
|
|
- sum[0] /= ne00;
|
|
|
- }
|
|
|
+ //for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
|
|
+ // if (get_local_id(0) < i) {
|
|
|
+ // sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
|
|
+ // }
|
|
|
+ //}
|
|
|
+ //if (get_local_id(0) == 0) {
|
|
|
+ // sum[0] /= ne00;
|
|
|
+ //}
|
|
|
|
|
|
- barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
+ //barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
+
|
|
|
+ sumf = sum[get_sub_group_local_id()];
|
|
|
+ sumf = sub_group_reduce_add(sumf);
|
|
|
|
|
|
- float mean = sum[0];
|
|
|
+ float mean = sumf / ne00;
|
|
|
float scale = 1.0f/sqrt(mean + eps);
|
|
|
|
|
|
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
|