softmax_avx512.s 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. //go:build amd64
  2. // +build amd64
  3. #include "textflag.h"
  // func softmaxMaxAVX512Asm(x *float32, n int) float32
  //
  // Returns the maximum of the n float32 values starting at x. The running
  // maximum is seeded by broadcasting ·negInfConst (defined outside this block;
  // presumably -Inf -- confirm), so n <= 0 returns that seed unchanged.
  //
  // Register roles:
  //   DI = read cursor into x
  //   CX = elements still to process
  //   Z0 = per-lane running maximum (lane 0 holds the scalar result at the end)
  //   Z1 = scratch
  //
  // Argument frame is 20 bytes: x (8) + n (8) + float32 result at offset 16 (4).
  TEXT ·softmaxMaxAVX512Asm(SB), NOSPLIT, $0-20
      MOVQ x+0(FP), DI
      MOVQ n+8(FP), CX
      VBROADCASTSS ·negInfConst(SB), Z0

  // Main loop: 16 floats (64 bytes) per iteration while at least 16 remain.
  loop_max:
      CMPQ CX, $16
      JL max_reduce
      VMOVUPS (DI), Z1
      VMAXPS Z1, Z0, Z0
      ADDQ $64, DI
      SUBQ $16, CX
      JMP loop_max

  // Horizontal reduction of the 16 lanes of Z0 into lane 0 of X0.
  max_reduce:
      VEXTRACTF32X8 $1, Z0, Y1     // upper 256 bits
      VMAXPS Y1, Y0, Y0            // 16 -> 8 lanes
      VEXTRACTF128 $1, Y0, X1      // upper 128 bits
      VMAXPS X1, X0, X0            // 8 -> 4 lanes
      VPERMILPS $0x4E, X0, X1      // swap the two 64-bit halves
      VMAXPS X1, X0, X0            // 4 -> 2 lanes
      VPERMILPS $0xB1, X0, X1      // swap adjacent 32-bit lanes
      VMAXPS X1, X0, X0            // 2 -> 1 lane

      // Scalar tail for the remaining (n mod 16) elements. JLE rather than JE
      // so that a negative n skips the tail instead of looping until CX wraps.
      TESTQ CX, CX
      JLE max_done

  max_tail:
      VMOVSS (DI), X1
      VMAXSS X1, X0, X0
      ADDQ $4, DI
      DECQ CX
      JNZ max_tail

  max_done:
      // Clear the dirty upper ZMM/YMM state before returning to code that may
      // use legacy SSE; the low 128 bits of X0 (the result) are preserved.
      VZEROUPPER
      VMOVSS X0, ret+16(FP)
      RET
  // func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
  //
  // For each full group of 16 float32 values starting at x, replaces the values
  // in place with an approximation of exp(x[i] - max) and returns the sum of
  // all values written.
  //
  // NOTE(review): only full groups of 16 are processed; the trailing n mod 16
  // elements are neither rewritten nor added to the sum. Presumably the Go
  // caller handles that remainder -- confirm. With n < 16 nothing is written
  // and the function returns 0.
  //
  // Per-vector computation (constants are data symbols defined outside this
  // block; names suggest the classic Cephes expf scheme -- confirm values):
  //   t = clamp(x - max, expLo, expHi)
  //   k = floor(t*log2EF + halfConst)
  //   r = t - k*expC1 - k*expC2
  //   y = (((((P0*r + P1)*r + P2)*r + P3)*r + P4)*r + P5) * r*r + r + oneConst
  //   result = y * float32frombits((k + expBiasConst) << 23)
  //
  // Register roles:
  //   DI = read/write cursor, CX = elements remaining
  //   Z5 = max (broadcast), Z7 = per-lane sum accumulator
  //   Z8..Z15 = oneConst, expC2, expC1, halfConst, log2EF, expLo, expHi,
  //             expBiasConst (in that order)
  //   Z16..Z21 = polyP0..polyP5
  //   Z0..Z4, Z6 = scratch
  TEXT ·softmaxExpSumAVX512Asm(SB), NOSPLIT, $0-28
      MOVQ x+0(FP), DI
      MOVQ n+8(FP), CX
      VMOVSS max+16(FP), X5
      VBROADCASTSS X5, Z5          // max in every lane

      VBROADCASTSS ·expHi(SB), Z14
      VBROADCASTSS ·expLo(SB), Z13
      VBROADCASTSS ·log2EF(SB), Z12
      VBROADCASTSS ·halfConst(SB), Z11
      VBROADCASTSS ·expC1(SB), Z10
      VBROADCASTSS ·expC2(SB), Z9
      VBROADCASTSS ·oneConst(SB), Z8
      VPBROADCASTD ·expBiasConst(SB), Z15

      // Polynomial coefficients are loop-invariant: load them once.
      VBROADCASTSS ·polyP0(SB), Z16
      VBROADCASTSS ·polyP1(SB), Z17
      VBROADCASTSS ·polyP2(SB), Z18
      VBROADCASTSS ·polyP3(SB), Z19
      VBROADCASTSS ·polyP4(SB), Z20
      VBROADCASTSS ·polyP5(SB), Z21

      VXORPS Z7, Z7, Z7            // sum accumulator = 0

  loop_exp:
      CMPQ CX, $16
      JL exp_reduce

      VMOVUPS (DI), Z1
      VSUBPS Z5, Z1, Z1            // t = x - max
      VMINPS Z14, Z1, Z1           // t = min(t, expHi)
      VMAXPS Z13, Z1, Z1           // t = max(t, expLo)

      VMULPS Z12, Z1, Z2           // t * log2EF
      VADDPS Z11, Z2, Z2           // + halfConst
      VRNDSCALEPS $1, Z2, Z2       // imm 1 = round toward -Inf (floor)
      VCVTPS2DQ Z2, Z6             // k as int32 (kept for the 2^k step)
      VCVTDQ2PS Z6, Z3             // k as float32

      VMULPS Z10, Z3, Z4
      VSUBPS Z4, Z1, Z1            // r = t - k*expC1
      VMULPS Z9, Z3, Z4
      VSUBPS Z4, Z1, Z1            // r -= k*expC2

      VMULPS Z1, Z1, Z4            // z = r*r

      // Horner evaluation of the degree-5 polynomial in r.
      VMULPS Z1, Z16, Z0           // y = P0*r
      VADDPS Z17, Z0, Z0           // y += P1
      VMULPS Z1, Z0, Z0
      VADDPS Z18, Z0, Z0           // y += P2
      VMULPS Z1, Z0, Z0
      VADDPS Z19, Z0, Z0           // y += P3
      VMULPS Z1, Z0, Z0
      VADDPS Z20, Z0, Z0           // y += P4
      VMULPS Z1, Z0, Z0
      VADDPS Z21, Z0, Z0           // y += P5
      VMULPS Z4, Z0, Z0            // y *= z
      VADDPS Z1, Z0, Z0            // y += r
      VADDPS Z8, Z0, Z0            // y += oneConst

      // Build the scale factor by placing (k + expBiasConst) in the float32
      // exponent field, then apply it.
      VPADDD Z15, Z6, Z6
      VPSLLD $23, Z6, Z6
      VMULPS Z6, Z0, Z0            // exp(x - max)

      VADDPS Z0, Z7, Z7            // accumulate
      VMOVUPS Z0, (DI)             // write result back in place

      ADDQ $64, DI
      SUBQ $16, CX
      JMP loop_exp

  // Horizontal sum of the 16 lanes of the accumulator Z7 into lane 0 of X0.
  exp_reduce:
      VEXTRACTF32X8 $1, Z7, Y1     // upper 256 bits of the accumulator
      VADDPS Y1, Y7, Y0            // + lower 256 bits of the accumulator (Y7, not Y0)
      VEXTRACTF128 $1, Y0, X1
      VADDPS X1, X0, X0            // 8 -> 4 lanes
      VHADDPS X0, X0, X0           // 4 -> 2 lanes
      VHADDPS X0, X0, X0           // 2 -> 1 lane

      // Clear the dirty upper ZMM/YMM state before returning to code that may
      // use legacy SSE; the low 128 bits of X0 (the result) are preserved.
      VZEROUPPER
      VMOVSS X0, ret+24(FP)
      RET
  // func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
  //
  // Multiplies each of the n float32 values starting at x by inv, in place.
  // n <= 0 leaves memory untouched.
  //
  // Register roles:
  //   DI = read/write cursor, CX = elements remaining
  //   Z1 = inv in every lane (X1 lane 0 is reused by the scalar tail)
  //   Z0 = scratch
  //
  // Argument frame is 20 bytes: x (8) + n (8) + inv (4); there is no result.
  TEXT ·softmaxScaleAVX512Asm(SB), NOSPLIT, $0-20
      MOVQ x+0(FP), DI
      MOVQ n+8(FP), CX
      VMOVSS inv+16(FP), X1
      VBROADCASTSS X1, Z1

  // Main loop: 16 floats (64 bytes) per iteration while at least 16 remain.
  loop_scale:
      CMPQ CX, $16
      JL scale_tail
      VMULPS (DI), Z1, Z0
      VMOVUPS Z0, (DI)
      ADDQ $64, DI
      SUBQ $16, CX
      JMP loop_scale

  // Scalar tail for the remaining (n mod 16) elements. JLE rather than JE so
  // that a negative n skips the tail instead of looping until CX wraps.
  scale_tail:
      TESTQ CX, CX
      JLE scale_done

  scale_tail_loop:
      // VEX-encoded scalar ops: legacy SSE MOVSS here would run with dirty
      // upper ZMM state left by the vector loop.
      VMOVSS (DI), X0
      VMULSS X1, X0, X0
      VMOVSS X0, (DI)
      ADDQ $4, DI
      DECQ CX
      JNZ scale_tail_loop

  scale_done:
      // Clear the dirty upper ZMM/YMM state before returning to code that may
      // use legacy SSE.
      VZEROUPPER
      RET