unary.cu

#include "unary.cuh"
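
// Element-wise GELU using the tanh approximation:
// gelu(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
// One thread per element; threads past k return early.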
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    float xi = x[i];
    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
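
// "Quick" GELU approximation: gelu_quick(x) = x * sigmoid(1.702*x).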
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
    const float GELU_QUICK_COEF = -1.702f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
}
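
// SiLU (swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).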
static __global__ void silu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] / (1.0f + expf(-x[i]));
}
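
// Element-wise hyperbolic tangent.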
static __global__ void tanh_f32(const float * x, float * dst, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = tanhf(x[i]);
}
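
// ReLU: relu(x) = max(x, 0).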
static __global__ void relu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0);
}
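
// Hard sigmoid: clamp((x + 3) / 6, 0, 1).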
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
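
// Hard swish: x * hardsigmoid(x).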
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
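
// Leaky ReLU: x for x > 0, negative_slope * x otherwise.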
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
}
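
// Element-wise square: sqr(x) = x * x.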
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * x[i];
}
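
// Host-side launch helpers: each rounds k up to a whole number of blocks of
// the op's block size and launches the corresponding kernel on the given stream.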
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}

static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
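
// ggml backend entry points: each op reads its single F32 source tensor,
// asserts the types, and runs the matching kernel over all elements on the
// context's CUDA stream.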
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float negative_slope;
    memcpy(&negative_slope, dst->op_params, sizeof(float));

    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
}

void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}