@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 }
 
-template <int block_size>
+template <int block_size, bool do_multiply = false>
 static __global__ void rms_norm_f32(
         const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
-        const int64_t stride_sample, const float eps) {
+        const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
+        const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
+        const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
     const int nrows = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
     x += sample*stride_sample + channel*stride_channel + row*stride_row;
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
+    if constexpr (do_multiply) {
+        const int mul_row = row % mul_nrows;
+        const int mul_channel = channel % mul_nchannels;
+        const int mul_sample = sample % mul_nsamples;
+        mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+    }
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * x[col];
+        if constexpr (do_multiply) {
+            const int mul_col = col % mul_ncols;
+            dst[col] = scale * x[col] * mul[mul_col];
+        } else {
+            dst[col] = scale * x[col];
+        }
     }
 }
 
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
+static void rms_norm_mul_f32_cuda(
+        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
+        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
+        const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (mul == nullptr) {
+        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+        return;
+    }
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     }
 }
 
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
+    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+    float eps = 0.0f;
+
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const float * src0_d = (const float *) rms_norm_src->data;
+    const float * mul_d = nullptr;
+    const ggml_tensor * mul_src = nullptr;
+
+    if (mul_tensor->src[0] == dst) {
+        mul_d = (float *) mul_tensor->src[1]->data;
+        mul_src = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == dst) {
+        mul_d = (float *) mul_tensor->src[0]->data;
+        mul_src = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    float * dst_d = (float *) mul_tensor->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(eps >= 0.0f);
+
+    const int64_t ne00 = rms_norm_src->ne[0];
+    const int64_t ne01 = rms_norm_src->ne[1];
+    const int64_t ne02 = rms_norm_src->ne[2];
+    const int64_t ne03 = rms_norm_src->ne[3];
+
+    const size_t ts0 = ggml_type_size(rms_norm_src->type);
+    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+    const int64_t s01 = rms_norm_src->nb[1] / ts0;
+    const int64_t s02 = rms_norm_src->nb[2] / ts0;
+    const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+    const size_t ts_mul = ggml_type_size(mul_src->type);
+    GGML_ASSERT(mul_src->nb[0] == ts_mul);
+    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+    const int mul_ncols = mul_src->ne[0];
+    const int mul_nrows = mul_src->ne[1];
+    const int mul_nchannels = mul_src->ne[2];
+    const int mul_nsamples = mul_src->ne[3];
+
+    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+}
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * grad = dst->src[0]; // gradients
     const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
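
For reference only, not part of the patch above: a minimal scalar sketch of what the fused kernel computes for one row when do_multiply is true, assuming contiguous f32 data and ignoring the channel/sample strides. It mirrors the per-row math in rms_norm_f32: a sum of squares, rsqrtf of the mean plus eps, then the element-wise product with the mul operand broadcast via modulo indexing. The helper name rms_norm_mul_row_ref is illustrative and does not exist in ggml.

// Illustrative host-side reference (assumption: contiguous rows, no strides)
// of the fused RMS-norm + multiply performed per row by the kernel above.
#include <cmath>

static void rms_norm_mul_row_ref(const float * x, const float * mul, float * dst,
                                 int ncols, int mul_ncols, float eps) {
    float sumsq = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        sumsq += x[col] * x[col];                               // the partial sum the warp reduction accumulates
    }
    const float scale = 1.0f / std::sqrt(sumsq / ncols + eps);  // rsqrtf(mean + eps) on the device
    for (int col = 0; col < ncols; ++col) {
        dst[col] = scale * x[col] * mul[col % mul_ncols];       // modulo broadcast of the mul operand
    }
}

The same modulo trick is applied to the row, channel, and sample indices inside the kernel, so ggml-style broadcasting of the mul tensor is handled without a separate MUL kernel launch.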