|
|
@@ -25,19 +25,8 @@ static __global__ void norm_f32(
|
|
|
}
|
|
|
|
|
|
// sum up partial sums
|
|
|
- mean_var = warp_reduce_sum(mean_var);
|
|
|
- if constexpr (block_size > WARP_SIZE) {
|
|
|
- static_assert(block_size == 1024, "unexpected block_size");
|
|
|
- __shared__ float2 s_sum[32];
|
|
|
- const int warp_id = threadIdx.x / WARP_SIZE;
|
|
|
- const int lane_id = threadIdx.x % WARP_SIZE;
|
|
|
- if (lane_id == 0) {
|
|
|
- s_sum[warp_id] = mean_var;
|
|
|
- }
|
|
|
- __syncthreads();
|
|
|
- mean_var = s_sum[lane_id];
|
|
|
- mean_var = warp_reduce_sum(mean_var);
|
|
|
- }
|
|
|
+ extern __shared__ float2 s_sum2[];
|
|
|
+ mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
|
|
|
|
|
|
const float mean = mean_var.x / ncols;
|
|
|
const float var = mean_var.y / ncols - mean * mean;
|
|
|
@@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
|
|
tmp += x[j];
|
|
|
}
|
|
|
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- if constexpr (block_size > WARP_SIZE) {
|
|
|
- static_assert(block_size == 1024, "unexpected block_size");
|
|
|
- __shared__ float s_sum[32];
|
|
|
- const int warp_id = threadIdx.x / WARP_SIZE;
|
|
|
- const int lane_id = threadIdx.x % WARP_SIZE;
|
|
|
- if (lane_id == 0) {
|
|
|
- s_sum[warp_id] = tmp;
|
|
|
- }
|
|
|
- __syncthreads();
|
|
|
- tmp = s_sum[lane_id];
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- }
|
|
|
+ extern __shared__ float s_sum[];
|
|
|
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
|
|
|
|
|
const float mean = tmp / group_size;
|
|
|
tmp = 0.0f;
|
|
|
@@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
|
|
tmp += xi * xi;
|
|
|
}
|
|
|
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- if (block_size > WARP_SIZE) {
|
|
|
- __shared__ float s_sum[32];
|
|
|
- const int warp_id = threadIdx.x / WARP_SIZE;
|
|
|
- const int lane_id = threadIdx.x % WARP_SIZE;
|
|
|
- if (lane_id == 0) {
|
|
|
- s_sum[warp_id] = tmp;
|
|
|
- }
|
|
|
- __syncthreads();
|
|
|
- tmp = s_sum[lane_id];
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- }
|
|
|
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
|
|
|
|
|
const float variance = tmp / group_size;
|
|
|
const float scale = rsqrtf(variance + eps);
|
|
|
@@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x,
|
|
|
}
|
|
|
|
|
|
// sum up partial sums
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- if constexpr (block_size > WARP_SIZE) {
|
|
|
- static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
|
|
|
- __shared__ float s_sum[32];
|
|
|
- const int warp_id = tid / WARP_SIZE;
|
|
|
- const int lane_id = tid % WARP_SIZE;
|
|
|
- if (lane_id == 0) {
|
|
|
- s_sum[warp_id] = tmp;
|
|
|
- }
|
|
|
- __syncthreads();
|
|
|
- tmp = 0.0f;
|
|
|
- if (lane_id < (block_size / WARP_SIZE)) {
|
|
|
- tmp = s_sum[lane_id];
|
|
|
- }
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- }
|
|
|
+ extern __shared__ float s_sum[];
|
|
|
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
|
|
|
|
|
const float mean = tmp / ncols;
|
|
|
const float scale = rsqrtf(mean + eps);
|
|
|
@@ -306,19 +259,8 @@ static __global__ void l2_norm_f32(
|
|
|
}
|
|
|
|
|
|
// sum up partial sums
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- if constexpr (block_size > WARP_SIZE) {
|
|
|
- static_assert(block_size == 1024, "unexpected block_size");
|
|
|
- __shared__ float s_sum[32];
|
|
|
- const int warp_id = threadIdx.x / WARP_SIZE;
|
|
|
- const int lane_id = threadIdx.x % WARP_SIZE;
|
|
|
- if (lane_id == 0) {
|
|
|
- s_sum[warp_id] = tmp;
|
|
|
- }
|
|
|
- __syncthreads();
|
|
|
- tmp = s_sum[lane_id];
|
|
|
- tmp = warp_reduce_sum(tmp);
|
|
|
- }
|
|
|
+ extern __shared__ float s_sum[];
|
|
|
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
|
|
|
|
|
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
|
|
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
|
|
@@ -337,7 +279,7 @@ static void norm_f32_cuda(
|
|
|
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
+ norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -348,7 +290,7 @@ static void group_norm_f32_cuda(
|
|
|
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
|
|
+ group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -358,10 +300,10 @@ static void rms_norm_f32_cuda(
|
|
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
|
if (ncols < 1024) {
|
|
|
const dim3 block_dims(256, 1, 1);
|
|
|
- rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
+ rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|
|
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
|
|
if (ncols < 1024) {
|
|
|
const dim3 block_dims(256, 1, 1);
|
|
|
- rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
+ rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
|
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
|
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
|
|
}
|
|
|
@@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|
|
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
|
|
|
if (ncols < 1024) {
|
|
|
const dim3 block_dims(256, 1, 1);
|
|
|
- rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
|
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
|
|
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
|
|
add_nchannels_packed, add_nsamples_packed);
|
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
|
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
|
|
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
|
|
@@ -460,7 +402,7 @@ static void l2_norm_f32_cuda(
|
|
|
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
+ l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
}
|
|
|
}
|
|
|
|