1
0

silu_avx512.s 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. //go:build amd64
  2. // +build amd64
  3. #include "textflag.h"
  4. // func siluAVX512Asm(x *float32, n int)
  5. TEXT ·siluAVX512Asm(SB), NOSPLIT, $0-16
  6. // Load args
  7. MOVQ x+0(FP), DI
  8. MOVQ n+8(FP), CX
  9. CMPQ CX, $0
  10. JLE done
  11. // Broadcast constants
  12. VBROADCASTSS ·expHi(SB), Z14
  13. VBROADCASTSS ·expLo(SB), Z13
  14. VBROADCASTSS ·log2EF(SB), Z12
  15. VBROADCASTSS ·halfConst(SB), Z11
  16. VBROADCASTSS ·expC1(SB), Z10
  17. VBROADCASTSS ·expC2(SB), Z9
  18. VBROADCASTSS ·oneConst(SB), Z8
  19. VPBROADCASTD ·signMaskConst(SB), Z15
  20. loop:
  21. CMPQ CX, $16
  22. JL done
  23. VMOVUPS (DI), Z0 // original x
  24. VMOVAPS Z0, Z1 // copy for neg
  25. VXORPS Z15, Z1, Z1 // z1 = -x
  26. VMINPS Z14, Z1, Z1 // clamp hi
  27. VMAXPS Z13, Z1, Z1 // clamp lo
  28. VMULPS Z12, Z1, Z2 // z2 = x * log2e
  29. VADDPS Z11, Z2, Z2 // +0.5
  30. VRNDSCALEPS $1, Z2, Z2 // floor
  31. VCVTPS2DQ Z2, Z6 // integer exponent
  32. VCVTDQ2PS Z6, Z5 // fx as float
  33. VMULPS Z10, Z5, Z3 // fx * C1
  34. VSUBPS Z3, Z1, Z1
  35. VMULPS Z9, Z5, Z3 // fx * C2
  36. VSUBPS Z3, Z1, Z1
  37. VMULPS Z1, Z1, Z3 // z = x*x
  38. VBROADCASTSS ·polyP0(SB), Z4
  39. VMULPS Z1, Z4, Z4
  40. VBROADCASTSS ·polyP1(SB), Z5
  41. VADDPS Z5, Z4, Z4
  42. VMULPS Z1, Z4, Z4
  43. VBROADCASTSS ·polyP2(SB), Z5
  44. VADDPS Z5, Z4, Z4
  45. VMULPS Z1, Z4, Z4
  46. VBROADCASTSS ·polyP3(SB), Z5
  47. VADDPS Z5, Z4, Z4
  48. VMULPS Z1, Z4, Z4
  49. VBROADCASTSS ·polyP4(SB), Z5
  50. VADDPS Z5, Z4, Z4
  51. VMULPS Z1, Z4, Z4
  52. VBROADCASTSS ·polyP5(SB), Z5
  53. VADDPS Z5, Z4, Z4
  54. VMULPS Z3, Z4, Z4 // y *= z
  55. VADDPS Z1, Z4, Z4 // y += x
  56. VADDPS Z8, Z4, Z4 // y += 1
  57. VPBROADCASTD ·expBiasConst(SB), Z5
  58. VPADDD Z5, Z6, Z6
  59. VPSLLD $23, Z6, Z6
  60. VMULPS Z6, Z4, Z4 // exp(-x)
  61. VADDPS Z8, Z4, Z3 // denom = 1 + exp(-x)
  62. VDIVPS Z3, Z8, Z3 // 1 / denom
  63. VMULPS Z0, Z3, Z0 // x * sigmoid(x)
  64. VMOVUPS Z0, (DI)
  65. ADDQ $64, DI
  66. SUBQ $16, CX
  67. JMP loop
  68. done:
  69. RET