unary-ops.cpp

#include "unary-ops.h"

static inline float op_abs(float x) {
    return fabsf(x);
}

static inline float op_sgn(float x) {
    return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
}

static inline float op_neg(float x) {
    return -x;
}

static inline float op_step(float x) {
    return (x > 0.f) ? 1.f : 0.f;
}

static inline float op_tanh(float x) {
    return tanhf(x);
}

static inline float op_elu(float x) {
    return (x > 0.f) ? x : expm1f(x);
}

static inline float op_relu(float x) {
    return (x > 0.f) ? x : 0.f;
}

static inline float op_sigmoid(float x) {
    return 1.f / (1.f + expf(-x));
}

static inline float op_hardsigmoid(float x) {
    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}

static inline float op_exp(float x) {
    return expf(x);
}

static inline float op_hardswish(float x) {
    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}

static inline float op_sqr(float x) {
    return x * x;
}

static inline float op_sqrt(float x) {
    return sqrtf(x);
}
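
// op_xielu: xIELU activation with per-tensor parameters.
// For x > 0 it evaluates the quadratic alpha_p * x^2 + beta * x; for x <= 0 it
// evaluates alpha_n * (expm1(min(x, eps)) - x) + beta * x, where eps clamps the
// argument of the exponential.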
static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;
    } else {
        const float min_x_eps = fminf(x, eps);
        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
    }
}

static inline float op_sin(float x) {
    return sinf(x);
}

static inline float op_cos(float x) {
    return cosf(x);
}

static inline float op_log(float x) {
    return logf(x);
}
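
// vec_unary_op: apply 'op' element-wise over a contiguous row of n values,
// converting src0_t -> f32 before the computation and f32 -> dst_t on store.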
template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;

    for (int i = 0; i < n; i++) {
        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
    }
}
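
// apply_unary_op: iterate over the rows assigned to this thread (via
// get_thread_range) and run the vectorized op on each row; src0 and dst must be
// contiguous along dim 0 and have the same shape.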
template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(dst_t));
    GGML_ASSERT(nb00 == sizeof(src0_t));

    const auto [ir0, ir1] = get_thread_range(params, src0);

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        dst_t        * dst_ptr  = (dst_t *)        ((char *)        dst->data + i03*nb3  + i02*nb2  + i01*nb1 );
        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
    }
}
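
// unary_op: dispatch on the src0/dst type combination and instantiate
// apply_unary_op with the matching conversion pair.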
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
template <float (*op)(float)>
static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
        apply_unary_op<op, float, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type));
        GGML_ABORT("fatal error");
    }
}
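
// unary_op_params: same type dispatch as unary_op, but for ops that also take the
// dst tensor (presumably to read op params); not referenced by the forwards below.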
template <float (*op)(float, ggml_tensor *)>
static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
        apply_unary_op<op, float, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type));
        GGML_ABORT("fatal error");
    }
}

// Extend vec_unary_op to support functors
template <typename Op, typename src0_t, typename dst_t>
static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;

    for (int i = 0; i < n; i++) {
        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
    }
}

// Extend apply_unary_op to support functors
template <typename Op, typename src0_t, typename dst_t>
static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(dst_t));
    GGML_ASSERT(nb00 == sizeof(src0_t));

    const auto [ir0, ir1] = get_thread_range(params, src0);

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        dst_t        * dst_ptr  = (dst_t *)        ((char *)        dst->data + i03*nb3  + i02*nb2  + i01*nb1 );
        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

        vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
    }
}

// Generic dispatcher for functors
template <typename Op>
static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
    const ggml_tensor * src0 = dst->src[0];

    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
        apply_unary_op_functor<Op, float, float>(params, dst, op);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
        apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
        apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type));
        GGML_ABORT("fatal error");
    }
}
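
// Forward entry points: one thin wrapper per ggml unary op, each instantiating
// the dispatcher with its element-wise function.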
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_abs>(params, dst);
}

void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sgn>(params, dst);
}

void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_neg>(params, dst);
}

void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_step>(params, dst);
}

void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_tanh>(params, dst);
}

void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_elu>(params, dst);
}

void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_relu>(params, dst);
}

void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sigmoid>(params, dst);
}

void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardsigmoid>(params, dst);
}

void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_exp>(params, dst);
}

void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardswish>(params, dst);
}

void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqr>(params, dst);
}

void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqrt>(params, dst);
}

void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sin>(params, dst);
}

void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_cos>(params, dst);
}

void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_log>(params, dst);
}
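
// ggml_compute_forward_xielu: read the xIELU parameters (alpha_n, alpha_p, beta, eps)
// from the op params of dst, capture them in a lambda around op_xielu, and run it
// through the functor dispatcher.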
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
    const float alpha_n = ggml_get_op_params_f32(dst, 1);
    const float alpha_p = ggml_get_op_params_f32(dst, 2);
    const float beta    = ggml_get_op_params_f32(dst, 3);
    const float eps     = ggml_get_op_params_f32(dst, 4);

    const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
        return op_xielu(f, alpha_n, alpha_p, beta, eps);
    };

    unary_op_functor(params, dst, xielu_op_params);
}