|
@@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
|
|
|
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
+
|
|
|
|
|
+ if (i >= k) {
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ dst[i] = expf(x[i]);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
|
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
if (i >= k) {
|
|
if (i >= k) {
|
|
@@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
|
|
|
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
+ const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
|
|
|
|
|
+ exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
|
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
|
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
|
|
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
|
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
|
@@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
+ const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
+ const float * src0_d = (const float *)src0->data;
|
|
|
|
|
+ float * dst_d = (float *)dst->data;
|
|
|
|
|
+ cudaStream_t stream = ctx.stream();
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
+
|
|
|
|
|
+ exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
const float * src0_d = (const float *)src0->data;
|