|
@@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
template <int block_size, bool do_multiply = false, bool do_add = false>
|
|
template <int block_size, bool do_multiply = false, bool do_add = false>
|
|
|
-static __global__ void rms_norm_f32(const float * x, float * dst,
|
|
|
|
|
|
|
+static __global__ void rms_norm_f32(const float * x,
|
|
|
|
|
+ float * dst,
|
|
|
const int ncols,
|
|
const int ncols,
|
|
|
const int64_t stride_row,
|
|
const int64_t stride_row,
|
|
|
const int64_t stride_channel,
|
|
const int64_t stride_channel,
|
|
|
const int64_t stride_sample,
|
|
const int64_t stride_sample,
|
|
|
const float eps,
|
|
const float eps,
|
|
|
- const float * mul = nullptr,
|
|
|
|
|
- const int64_t mul_stride_row = 0,
|
|
|
|
|
- const int64_t mul_stride_channel = 0,
|
|
|
|
|
- const int64_t mul_stride_sample = 0,
|
|
|
|
|
- const int mul_ncols = 0,
|
|
|
|
|
- const int mul_nrows = 0,
|
|
|
|
|
- const int mul_nchannels = 0,
|
|
|
|
|
- const int mul_nsamples = 0,
|
|
|
|
|
- const float * add = nullptr,
|
|
|
|
|
- const int64_t add_stride_row = 0,
|
|
|
|
|
- const int64_t add_stride_channel = 0,
|
|
|
|
|
- const int64_t add_stride_sample = 0,
|
|
|
|
|
- const int add_ncols = 0,
|
|
|
|
|
- const int add_nrows = 0,
|
|
|
|
|
- const int add_nchannels = 0,
|
|
|
|
|
- const int add_nsamples = 0) {
|
|
|
|
|
-
|
|
|
|
|
|
|
+ const float * mul = nullptr,
|
|
|
|
|
+ const int64_t mul_stride_row = 0,
|
|
|
|
|
+ const int64_t mul_stride_channel = 0,
|
|
|
|
|
+ const int64_t mul_stride_sample = 0,
|
|
|
|
|
+ const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const float * add = nullptr,
|
|
|
|
|
+ const int64_t add_stride_row = 0,
|
|
|
|
|
+ const int64_t add_stride_channel = 0,
|
|
|
|
|
+ const int64_t add_stride_sample = 0,
|
|
|
|
|
+ const uint3 add_ncols_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const uint3 add_nrows_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
|
|
|
|
|
+ const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
|
|
|
const int nrows = gridDim.x;
|
|
const int nrows = gridDim.x;
|
|
|
const int nchannels = gridDim.y;
|
|
const int nchannels = gridDim.y;
|
|
|
|
|
|
|
@@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
|
|
|
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
|
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
|
|
|
|
|
|
|
if constexpr (do_multiply) {
|
|
if constexpr (do_multiply) {
|
|
|
- const int mul_row = row % mul_nrows;
|
|
|
|
|
- const int mul_channel = channel % mul_nchannels;
|
|
|
|
|
- const int mul_sample = sample % mul_nsamples;
|
|
|
|
|
- mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
|
|
|
|
|
|
|
+ const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
|
|
|
|
|
+ const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
|
|
|
|
|
+ const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
|
|
|
|
|
+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if constexpr (do_add) {
|
|
if constexpr (do_add) {
|
|
|
- const int add_row = row % add_nrows;
|
|
|
|
|
- const int add_channel = channel % add_nchannels;
|
|
|
|
|
- const int add_sample = sample % add_nsamples;
|
|
|
|
|
|
|
+ const int add_row = fastmodulo(row, add_nrows_packed);
|
|
|
|
|
+ const int add_channel = fastmodulo(channel, add_nchannels_packed);
|
|
|
|
|
+ const int add_sample = fastmodulo(sample, add_nsamples_packed);
|
|
|
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
|
|
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
|
|
|
// sum up partial sums
|
|
// sum up partial sums
|
|
|
tmp = warp_reduce_sum(tmp);
|
|
tmp = warp_reduce_sum(tmp);
|
|
|
if constexpr (block_size > WARP_SIZE) {
|
|
if constexpr (block_size > WARP_SIZE) {
|
|
|
- static_assert(block_size == 1024, "unexpected block_size");
|
|
|
|
|
|
|
+ static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
|
|
|
__shared__ float s_sum[32];
|
|
__shared__ float s_sum[32];
|
|
|
- const int warp_id = threadIdx.x / WARP_SIZE;
|
|
|
|
|
- const int lane_id = threadIdx.x % WARP_SIZE;
|
|
|
|
|
|
|
+ const int warp_id = tid / WARP_SIZE;
|
|
|
|
|
+ const int lane_id = tid % WARP_SIZE;
|
|
|
if (lane_id == 0) {
|
|
if (lane_id == 0) {
|
|
|
s_sum[warp_id] = tmp;
|
|
s_sum[warp_id] = tmp;
|
|
|
}
|
|
}
|
|
|
__syncthreads();
|
|
__syncthreads();
|
|
|
- tmp = s_sum[lane_id];
|
|
|
|
|
|
|
+ tmp = 0.0f;
|
|
|
|
|
+ if (lane_id < (block_size / WARP_SIZE)) {
|
|
|
|
|
+ tmp = s_sum[lane_id];
|
|
|
|
|
+ }
|
|
|
tmp = warp_reduce_sum(tmp);
|
|
tmp = warp_reduce_sum(tmp);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
|
|
|
|
|
|
|
|
for (int col = tid; col < ncols; col += block_size) {
|
|
for (int col = tid; col < ncols; col += block_size) {
|
|
|
if constexpr (do_multiply && do_add) {
|
|
if constexpr (do_multiply && do_add) {
|
|
|
- const int mul_col = col % mul_ncols;
|
|
|
|
|
- const int add_col = col % add_ncols;
|
|
|
|
|
- dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
|
|
|
|
|
|
|
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
|
|
|
|
|
+ const int add_col = fastmodulo(col, add_ncols_packed);
|
|
|
|
|
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
|
|
|
} else if constexpr (do_multiply) {
|
|
} else if constexpr (do_multiply) {
|
|
|
- const int mul_col = col % mul_ncols;
|
|
|
|
|
- dst[col] = scale * x[col] * mul[mul_col];
|
|
|
|
|
|
|
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
|
|
|
|
|
+ dst[col] = scale * x[col] * mul[mul_col];
|
|
|
} else {
|
|
} else {
|
|
|
dst[col] = scale * x[col];
|
|
dst[col] = scale * x[col];
|
|
|
}
|
|
}
|
|
@@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
|
|
|
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
|
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
|
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
|
if (ncols < 1024) {
|
|
if (ncols < 1024) {
|
|
|
- const dim3 block_dims(WARP_SIZE, 1, 1);
|
|
|
|
|
- rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
|
|
|
|
+ const dim3 block_dims(256, 1, 1);
|
|
|
|
|
+ rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
} else {
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-static void rms_norm_mul_f32_cuda(const float * x,
|
|
|
|
|
- const float * mul,
|
|
|
|
|
- const float * add,
|
|
|
|
|
- float * dst,
|
|
|
|
|
- const int ncols,
|
|
|
|
|
- const int nrows,
|
|
|
|
|
- const int nchannels,
|
|
|
|
|
- const int nsamples,
|
|
|
|
|
- const int64_t stride_row,
|
|
|
|
|
- const int64_t stride_channel,
|
|
|
|
|
- const int64_t stride_sample,
|
|
|
|
|
- const int64_t mul_stride_row,
|
|
|
|
|
- const int64_t mul_stride_channel,
|
|
|
|
|
- const int64_t mul_stride_sample,
|
|
|
|
|
- const int mul_ncols,
|
|
|
|
|
- const int mul_nrows,
|
|
|
|
|
- const int mul_nchannels,
|
|
|
|
|
- const int mul_nsamples,
|
|
|
|
|
- const int64_t add_stride_row,
|
|
|
|
|
- const int64_t add_stride_channel,
|
|
|
|
|
- const int64_t add_stride_sample,
|
|
|
|
|
- const int add_ncols,
|
|
|
|
|
- const int add_nrows,
|
|
|
|
|
- const int add_nchannels,
|
|
|
|
|
- const int add_nsamples,
|
|
|
|
|
- const float eps,
|
|
|
|
|
- cudaStream_t stream) {
|
|
|
|
|
|
|
+static void rms_norm_mul_f32_cuda(const float * x,
|
|
|
|
|
+ const float * mul,
|
|
|
|
|
+ const float * add,
|
|
|
|
|
+ float * dst,
|
|
|
|
|
+ const int ncols,
|
|
|
|
|
+ const int nrows,
|
|
|
|
|
+ const int nchannels,
|
|
|
|
|
+ const int nsamples,
|
|
|
|
|
+ const int64_t stride_row,
|
|
|
|
|
+ const int64_t stride_channel,
|
|
|
|
|
+ const int64_t stride_sample,
|
|
|
|
|
+ const int64_t mul_stride_row,
|
|
|
|
|
+ const int64_t mul_stride_channel,
|
|
|
|
|
+ const int64_t mul_stride_sample,
|
|
|
|
|
+ const uint32_t mul_ncols,
|
|
|
|
|
+ const uint32_t mul_nrows,
|
|
|
|
|
+ const uint32_t mul_nchannels,
|
|
|
|
|
+ const uint32_t mul_nsamples,
|
|
|
|
|
+ const int64_t add_stride_row,
|
|
|
|
|
+ const int64_t add_stride_channel,
|
|
|
|
|
+ const int64_t add_stride_sample,
|
|
|
|
|
+ const uint32_t add_ncols,
|
|
|
|
|
+ const uint32_t add_nrows,
|
|
|
|
|
+ const uint32_t add_nchannels,
|
|
|
|
|
+ const uint32_t add_nsamples,
|
|
|
|
|
+ const float eps,
|
|
|
|
|
+ cudaStream_t stream) {
|
|
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
|
if (mul == nullptr) {
|
|
if (mul == nullptr) {
|
|
|
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
|
|
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
if (add == nullptr) {
|
|
if (add == nullptr) {
|
|
|
|
|
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
|
|
|
|
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
|
|
|
|
|
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
|
|
|
|
|
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
|
|
if (ncols < 1024) {
|
|
if (ncols < 1024) {
|
|
|
- const dim3 block_dims(WARP_SIZE, 1, 1);
|
|
|
|
|
- rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
|
|
|
|
- ncols, stride_row, stride_channel, stride_sample, eps,
|
|
|
|
|
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
|
|
|
|
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
|
|
|
|
|
|
|
+ const dim3 block_dims(256, 1, 1);
|
|
|
|
|
+ rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
|
|
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
|
|
} else {
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
|
|
|
|
- ncols, stride_row, stride_channel, stride_sample, eps,
|
|
|
|
|
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
|
|
|
|
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
|
|
|
|
|
|
|
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
|
|
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
|
|
}
|
|
}
|
|
|
} else {
|
|
} else {
|
|
|
|
|
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
|
|
|
|
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
|
|
|
|
|
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
|
|
|
|
|
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
|
|
|
|
+
|
|
|
|
|
+ const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
|
|
|
|
|
+ const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
|
|
|
|
|
+ const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
|
|
|
|
|
+ const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
|
|
|
if (ncols < 1024) {
|
|
if (ncols < 1024) {
|
|
|
- const dim3 block_dims(WARP_SIZE, 1, 1);
|
|
|
|
|
- rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
|
|
|
|
- ncols, stride_row, stride_channel, stride_sample, eps,
|
|
|
|
|
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
|
|
|
|
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
|
|
|
|
- add, add_stride_row, add_stride_channel, add_stride_sample,
|
|
|
|
|
- add_ncols, add_nrows, add_nchannels, add_nsamples);
|
|
|
|
|
|
|
+ const dim3 block_dims(256, 1, 1);
|
|
|
|
|
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
|
|
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
|
|
|
|
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
|
|
|
|
+ add_nchannels_packed, add_nsamples_packed);
|
|
|
} else {
|
|
} else {
|
|
|
const dim3 block_dims(1024, 1, 1);
|
|
const dim3 block_dims(1024, 1, 1);
|
|
|
- rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
|
|
|
|
- ncols, stride_row, stride_channel, stride_sample, eps,
|
|
|
|
|
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
|
|
|
|
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
|
|
|
|
- add, add_stride_row, add_stride_channel, add_stride_sample,
|
|
|
|
|
- add_ncols, add_nrows, add_nchannels, add_nsamples);
|
|
|
|
|
|
|
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
|
|
|
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
|
|
|
|
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
|
|
|
|
+ add_nchannels_packed, add_nsamples_packed);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|