#include "unary-ops.h"
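
// Scalar unary ops for the ggml CPU backend: each op_* helper computes a
// single element in f32, and the templates further down lift them over whole
// tensors, converting between the storage type (f32/f16/bf16) and f32 on the fly.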

static inline float op_abs(float x) {
    return fabsf(x);
}

static inline float op_sgn(float x) {
    return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
}

static inline float op_neg(float x) {
    return -x;
}

static inline float op_step(float x) {
    return (x > 0.f) ? 1.f : 0.f;
}

static inline float op_tanh(float x) {
    return tanhf(x);
}

static inline float op_elu(float x) {
    return (x > 0.f) ? x : expm1f(x);
}

static inline float op_relu(float x) {
    return (x > 0.f) ? x : 0.f;
}

static inline float op_sigmoid(float x) {
    return 1.f / (1.f + expf(-x));
}

static inline float op_hardsigmoid(float x) {
    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}

static inline float op_exp(float x) {
    return expf(x);
}

static inline float op_hardswish(float x) {
    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}

static inline float op_sqr(float x) {
    return x * x;
}

static inline float op_sqrt(float x) {
    return sqrtf(x);
}
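
// XiELU: a quadratic branch alpha_p*x^2 + beta*x for x > 0, and an
// expm1-based branch alpha_n*(expm1(min(x, eps)) - x) + beta*x otherwise;
// the min(x, eps) clamp bounds the argument passed to expm1f.
// e.g. with alpha_p = 0.8f, beta = 0.5f: op_xielu(2.0f, ...) = 0.8*4 + 0.5*2 = 4.2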
static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;
    } else {
        const float min_x_eps = fminf(x, eps);
        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
    }
}

static inline float op_sin(float x) {
    return sinf(x);
}

static inline float op_cos(float x) {
    return cosf(x);
}

static inline float op_log(float x) {
    return logf(x);
}
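
// The templates below lift a scalar op over tensors: vec_unary_op applies it
// to one contiguous row, apply_unary_op walks the rows assigned to the current
// thread, and unary_op dispatches on the src/dst type combination.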
template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;

    // use an int64_t counter so rows longer than INT_MAX cannot overflow it
    for (int64_t i = 0; i < n; i++) {
        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
    }
}
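
// For example, vec_unary_op<op_relu, ggml_fp16_t, float>(n, y, x) reads fp16
// inputs, applies ReLU in f32, and stores f32 results.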

template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(dst_t));
    GGML_ASSERT(nb00 == sizeof(src0_t));

    // rows assigned to this thread
    const auto [ir0, ir1] = get_thread_range(params, src0);

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // unflatten the row index ir into (i03, i02, i01) coordinates
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
    }
}

// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
template <float (*op)(float)>
static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
        apply_unary_op<op, float, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type));
        GGML_ABORT("fatal error");
    }
}

// forward declaration - the functor dispatcher is defined below
template <typename Op>
static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op);

// Dispatcher for ops that also read parameters from the dst tensor. A plain
// float (*)(float) template argument cannot carry the extra tensor argument
// through apply_unary_op, so bind dst into a lambda and reuse the functor path.
template <float (*op)(float, ggml_tensor *)>
static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op_functor(params, dst, [dst](float f) { return op(f, dst); });
}
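
// Function-pointer template parameters cannot carry runtime state, so ops
// whose parameters are only known at run time (XiELU's alpha_n/alpha_p/beta/eps
// scalars) go through the functor-based variants below, which accept any
// callable, including capturing lambdas.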
// Extend vec_unary_op to support functors
template <typename Op, typename src0_t, typename dst_t>
static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;

    // int64_t counter, same overflow consideration as vec_unary_op
    for (int64_t i = 0; i < n; i++) {
        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
    }
}

// Extend apply_unary_op to support functors
template <typename Op, typename src0_t, typename dst_t>
static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(dst_t));
    GGML_ASSERT(nb00 == sizeof(src0_t));

    const auto [ir0, ir1] = get_thread_range(params, src0);

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

        vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
    }
}

// Generic dispatcher for functors
template <typename Op>
static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
    const ggml_tensor * src0 = dst->src[0];

    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
        apply_unary_op_functor<Op, float, float>(params, dst, op);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
        apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
        apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type));
        GGML_ABORT("fatal error");
    }
}
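
// Adding a new unary op is two steps: define the scalar helper, then wire a
// forward wrapper through unary_op. Hypothetical 'softsign' shown for
// illustration; it is not part of this file:
//
//   static inline float op_softsign(float x) {
//       return x / (1.0f + fabsf(x));
//   }
//
//   void ggml_compute_forward_softsign(const ggml_compute_params * params, ggml_tensor * dst) {
//       unary_op<op_softsign>(params, dst);
//   }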

void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_abs>(params, dst);
}

void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sgn>(params, dst);
}

void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_neg>(params, dst);
}

void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_step>(params, dst);
}

void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_tanh>(params, dst);
}

void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_elu>(params, dst);
}

void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_relu>(params, dst);
}

void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sigmoid>(params, dst);
}

void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardsigmoid>(params, dst);
}

void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_exp>(params, dst);
}

void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardswish>(params, dst);
}

void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqr>(params, dst);
}

void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqrt>(params, dst);
}

void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sin>(params, dst);
}

void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_cos>(params, dst);
}

void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_log>(params, dst);
}
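
// XiELU's scalars travel in dst->op_params: slot 0 carries the unary-op id,
// so the four floats are read starting at index 1.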
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
    const float alpha_n = ggml_get_op_params_f32(dst, 1);
    const float alpha_p = ggml_get_op_params_f32(dst, 2);
    const float beta    = ggml_get_op_params_f32(dst, 3);
    const float eps     = ggml_get_op_params_f32(dst, 4);

    const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
        return op_xielu(f, alpha_n, alpha_p, beta, eps);
    };

    unary_op_functor(params, dst, xielu_op_params);
}