// unary.cu

#include "unary.cuh"
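
// Element-wise unary activation kernels for the ggml CUDA backend.
//
// Every kernel below maps one thread to one element: the global index is
// blockDim.x*blockIdx.x + threadIdx.x, and threads whose index is >= the
// element count k return early. The *_cuda() helpers further down size the
// grid and launch on the caller-supplied stream.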

static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    float xi = x[i];
    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
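
// gelu_f32 above is the tanh approximation of GELU; gelu_quick_f32 below is
// the cheaper sigmoid form x * sigmoid(1.702 * x).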

static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
    const float GELU_QUICK_COEF = -1.702f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
}

static __global__ void silu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] / (1.0f + expf(-x[i]));
}

static __global__ void tanh_f32(const float * x, float * dst, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = tanhf(x[i]);
}

static __global__ void relu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0);
}

static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = 1.0f / (1.0f + expf(-x[i]));
}

static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}

static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}

static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
}

static __global__ void sqr_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * x[i];
}
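
// Host-side launch helpers. Each one picks the grid size by ceiling division
// so that every element is covered; for example, k = 1000 with a block size
// of 256 gives (1000 + 255) / 256 = 4 blocks. The CUDA_*_BLOCK_SIZE constants
// are provided through the unary.cuh include.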

static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}

static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
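
// ggml backend entry points. Each op takes its single input from dst->src[0],
// asserts that the input is contiguous and that input and output are F32, and
// runs the matching kernel over ggml_nelements(src0) elements on the
// context's stream.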

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
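
// ggml_cuda_op_leaky_relu additionally pulls the negative slope out of the
// op's parameters; memcpy is used so the float is read from the op_params
// buffer without type-punning.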
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float negative_slope;
    memcpy(&negative_slope, dst->op_params, sizeof(float));

    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
}

void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}