1
0

softmax_avx2.s 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. //go:build amd64
  2. // +build amd64
  3. #include "textflag.h"
  // func softmaxMaxAVX2Asm(x *float32, n int) float32
  //
  // Returns the largest of the n float32 values starting at x.
  // The running maximum is seeded from ·negInfConst, so that seed value is
  // what comes back when n <= 0 (no element is read in that case).
  //
  // ABI:      Go ABI0 (stack-based arguments, addressed through FP)
  // Args:     x+0(FP)   pointer to the first element
  //           n+8(FP)   element count
  // Result:   ret+16(FP)
  // Clobbers: DI, CX, Y0, Y1, flags
  // Frame:    no locals; 20 bytes of arguments + result (8 + 8 + 4)
  TEXT ·softmaxMaxAVX2Asm(SB), NOSPLIT, $0-20
      MOVQ x+0(FP), DI                      // DI = read cursor
      MOVQ n+8(FP), CX                      // CX = elements still to visit
      VBROADCASTSS ·negInfConst(SB), Y0     // Y0 = 8 running maxima

  // Vector body: 8 floats per iteration while at least 8 remain.
  loop_max:
      CMPQ CX, $8
      JL max_reduce                         // signed compare: n is a Go int
      VMOVUPS (DI), Y1
      VMAXPS Y1, Y0, Y0
      ADDQ $32, DI
      SUBQ $8, CX
      JMP loop_max

  // Horizontal reduction of the 8 lanes of Y0 into lane 0 of X0.
  max_reduce:
      // Pull the upper 128 bits out FIRST. Any VEX-encoded write to X0 clears
      // bits 255:128 of Y0, so extracting the low half into X0 beforehand
      // would destroy lanes 4-7 and fold zeros into the result instead.
      // X0 already aliases the low half, so no second extract is needed.
      VEXTRACTF128 $1, Y0, X1
      VMAXPS X1, X0, X0                     // 8 lanes -> 4
      VPERMILPS $0x4E, X0, X1               // swap the two 64-bit halves
      VMAXPS X1, X0, X0                     // 4 lanes -> 2
      VPERMILPS $0xB1, X0, X1               // swap neighbours within each pair
      VMAXPS X1, X0, X0                     // 2 lanes -> 1 (lane 0)

      // Scalar tail: the 0..7 elements left after the vector body.
      // JLE (not JE) so that a negative n skips the loop instead of
      // decrementing away from zero and reading far out of bounds.
      TESTQ CX, CX
      JLE max_done
  max_tail:
      VMOVSS (DI), X1
      VMAXSS X1, X0, X0
      ADDQ $4, DI
      DECQ CX
      JNZ max_tail

  max_done:
      // Leave the upper YMM state clean before executing a legacy-SSE store
      // and returning to compiler-generated code; the low 128 bits of X0 are
      // unaffected by VZEROUPPER.
      VZEROUPPER
      MOVSS X0, ret+16(FP)
      RET
  // func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
  //
  // For every complete group of 8 float32 values starting at x, overwrites
  // the values in place with a polynomial approximation of exp(x[i] - max)
  // and returns the sum of everything that was written.
  //
  // NOTE(review): only whole groups of 8 are processed. The trailing n % 8
  // elements are neither rewritten nor added to the sum, unlike the scalar
  // tails in softmaxMaxAVX2Asm and softmaxScaleAVX2Asm. Presumably the Go
  // caller finishes the remainder itself - confirm against the call site.
  //
  // All exp constants (expHi, expLo, log2EF, halfConst, expC1, expC2,
  // oneConst, expBiasConst, polyP0..polyP5) are defined outside this file;
  // the comments below describe how they are used, not their values.
  //
  // ABI:      Go ABI0 (stack-based arguments, addressed through FP)
  // Args:     x+0(FP), n+8(FP), max+16(FP)
  // Result:   ret+24(FP)
  // Clobbers: DI, CX, Y0-Y15, flags
  TEXT ·softmaxExpSumAVX2Asm(SB), NOSPLIT, $0-28
      MOVQ x+0(FP), DI                      // DI = read/write cursor
      MOVQ n+8(FP), CX                      // CX = elements still to visit
      MOVSS max+16(FP), X5
      VBROADCASTSS X5, Y5                   // Y5  = max in every lane

      // Loop-invariant constants kept in registers for the whole loop.
      VBROADCASTSS ·expHi(SB), Y14          // upper clamp
      VBROADCASTSS ·expLo(SB), Y13          // lower clamp
      VBROADCASTSS ·log2EF(SB), Y12
      VBROADCASTSS ·halfConst(SB), Y11
      VBROADCASTSS ·expC1(SB), Y10
      VBROADCASTSS ·expC2(SB), Y9
      VBROADCASTSS ·oneConst(SB), Y8
      VPBROADCASTD ·expBiasConst(SB), Y15   // integer added to k before the shift
      VXORPS Y7, Y7, Y7                     // Y7  = 8 partial sums, start at 0

  loop_exp:
      CMPQ CX, $8
      JL exp_reduce

      // t = clamp(x - max, expLo, expHi)
      VMOVUPS (DI), Y1
      VSUBPS Y5, Y1, Y1
      VMINPS Y14, Y1, Y1
      VMAXPS Y13, Y1, Y1

      // k = floor(t * log2EF + halfConst); Y6 = k as int32, Y3 = k as float32
      VMULPS Y12, Y1, Y2
      VADDPS Y11, Y2, Y2
      VROUNDPS $1, Y2, Y2                   // imm 1 = round toward -infinity
      VCVTPS2DQ Y2, Y6
      VCVTDQ2PS Y6, Y3

      // r = t - k*expC1 - k*expC2 (range reduction in two steps)
      VMULPS Y10, Y3, Y4
      VSUBPS Y4, Y1, Y1
      VMULPS Y9, Y3, Y4
      VSUBPS Y4, Y1, Y1

      VMULPS Y1, Y1, Y4                     // Y4 = r*r

      // Horner evaluation: p = ((((P0*r + P1)*r + P2)*r + P3)*r + P4)*r + P5.
      // The coefficients are re-broadcast each iteration because all sixteen
      // YMM registers are already in use (Y3 is free again at this point).
      VBROADCASTSS ·polyP0(SB), Y0
      VMULPS Y1, Y0, Y0
      VBROADCASTSS ·polyP1(SB), Y3
      VADDPS Y3, Y0, Y0
      VMULPS Y1, Y0, Y0
      VBROADCASTSS ·polyP2(SB), Y3
      VADDPS Y3, Y0, Y0
      VMULPS Y1, Y0, Y0
      VBROADCASTSS ·polyP3(SB), Y3
      VADDPS Y3, Y0, Y0
      VMULPS Y1, Y0, Y0
      VBROADCASTSS ·polyP4(SB), Y3
      VADDPS Y3, Y0, Y0
      VMULPS Y1, Y0, Y0
      VBROADCASTSS ·polyP5(SB), Y3
      VADDPS Y3, Y0, Y0

      // y = p*r*r + r + oneConst
      VMULPS Y4, Y0, Y0
      VADDPS Y1, Y0, Y0
      VADDPS Y8, Y0, Y0

      // Scale by 2^k: (k + expBiasConst) << 23 places the sum in the float32
      // exponent field; the resulting bits are then used as a multiplier.
      VPADDD Y15, Y6, Y6
      VPSLLD $23, Y6, Y6
      VMULPS Y6, Y0, Y0                     // Y0 ~= exp(x - max)

      VADDPS Y0, Y7, Y7                     // accumulate
      VMOVUPS Y0, (DI)                      // write the exponentials back in place
      ADDQ $32, DI
      SUBQ $8, CX
      JMP loop_exp

  // Horizontal add of the 8 partial sums in Y7 into lane 0 of X0.
  exp_reduce:
      VEXTRACTF128 $0, Y7, X0
      VEXTRACTF128 $1, Y7, X1
      VADDPS X1, X0, X0                     // 8 lanes -> 4
      VHADDPS X0, X0, X0                    // 4 lanes -> 2
      VHADDPS X0, X0, X0                    // 2 lanes -> 1

      // Leave the upper YMM state clean before executing a legacy-SSE store
      // and returning to compiler-generated code; the low 128 bits of X0 are
      // unaffected by VZEROUPPER.
      VZEROUPPER
      MOVSS X0, ret+24(FP)
      RET
  // func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
  //
  // Multiplies the n float32 values starting at x by inv, in place.
  // Groups of 8 are handled with 256-bit multiplies; the remaining 0..7
  // elements are handled one at a time. Nothing is touched when n <= 0.
  //
  // ABI:      Go ABI0 (stack-based arguments, addressed through FP)
  // Args:     x+0(FP), n+8(FP), inv+16(FP)
  // Result:   none
  // Clobbers: DI, CX, Y0, Y1, flags
  // Frame:    no locals; 20 bytes of arguments (8 + 8 + 4)
  TEXT ·softmaxScaleAVX2Asm(SB), NOSPLIT, $0-20
      MOVQ x+0(FP), DI                      // DI = read/write cursor
      MOVQ n+8(FP), CX                      // CX = elements still to visit
      MOVSS inv+16(FP), X1
      VBROADCASTSS X1, Y1                   // Y1 = inv in every lane

  loop_scale:
      CMPQ CX, $8
      JL scale_tail                         // signed compare: n is a Go int
      VMULPS (DI), Y1, Y0
      VMOVUPS Y0, (DI)
      ADDQ $32, DI
      SUBQ $8, CX
      JMP loop_scale

  scale_tail:
      // The 256-bit work is finished: clear the upper YMM halves so that
      // neither the scalar tail nor the Go caller pays an AVX/SSE transition
      // penalty. Lane 0 of X1 still holds inv afterwards.
      VZEROUPPER
      // JLE (not JE) so that a negative n skips the loop instead of
      // decrementing away from zero and writing far out of bounds.
      TESTQ CX, CX
      JLE scale_done
  scale_tail_loop:
      VMOVSS (DI), X0
      VMULSS X1, X0, X0
      VMOVSS X0, (DI)
      ADDQ $4, DI
      DECQ CX
      JNZ scale_tail_loop

  scale_done:
      RET
  // ·negInfConst: one 4-byte read-only word holding 0xff800000, the IEEE-754
  // single-precision bit pattern for -Inf. softmaxMaxAVX2Asm broadcasts it to
  // seed its running maximum.
  DATA ·negInfConst(SB)/4, $0xff800000
  GLOBL ·negInfConst(SB), RODATA, $4