// simd_avx2.s
  1. //go:build amd64
  2. // +build amd64
  3. #include "textflag.h"
  // func dotAVX2(a *float32, b *float32, n int) float32
  //
  // Returns the sum of a[i]*b[i] for i in [0, n). Returns 0 when n <= 0.
  //
  // Argument frame: a+0, b+8, n+16, ret+24 (float32) => 28 bytes total.
  //
  // Register roles:
  //   DI = cursor into a        SI = cursor into b
  //   CX = elements remaining   Y0/X0 = running accumulator
  //   Y1/X1, Y2/X2 = scratch
  //
  // All loads are unaligned (VMOVUPS/VMOVSS), so no alignment is required
  // of a or b. The function uses FMA instructions (VFMADD231PS/SS) and does
  // no CPU feature check itself.
  //
  // Products are accumulated lane-wise with fused multiply-add, so the
  // result can differ in the last bits from a plain left-to-right scalar loop.
  //
  // Only VEX-encoded instructions are used while the upper YMM halves may be
  // dirty, and every return path executes VZEROUPPER.
  TEXT ·dotAVX2(SB), NOSPLIT, $0-28
      MOVQ a+0(FP), DI
      MOVQ b+8(FP), SI
      MOVQ n+16(FP), CX

      // Accumulator Y0 = 0 (also the result for n <= 0).
      VXORPS Y0, Y0, Y0
      TESTQ CX, CX
      JLE dot_store

      // Main loop: 8 floats per iteration.
  loop8:
      CMPQ CX, $8
      JL fold
      VMOVUPS (DI), Y1
      VMOVUPS (SI), Y2
      VFMADD231PS Y1, Y2, Y0 // Y0 += Y1 * Y2
      ADDQ $32, DI
      ADDQ $32, SI
      SUBQ $8, CX
      JMP loop8

      // Fold the upper half of Y0 into X0 before handling the tails, so the
      // remaining work only needs a 4-lane accumulator.
  fold:
      VEXTRACTF128 $1, Y0, X1
      VADDPS X1, X0, X0

      // 4-wide step: here CX < 8, so this runs at most once.
  loop4:
      CMPQ CX, $4
      JL loop_scalar
      VMOVUPS (DI), X1
      VMOVUPS (SI), X2
      VFMADD231PS X1, X2, X0 // X0 += X1 * X2
      ADDQ $16, DI
      ADDQ $16, SI
      SUBQ $4, CX
      JMP loop4

      // Scalar tail: here CX < 4. Each product is added into lane 0 of X0;
      // lanes 1..3 are left untouched by the scalar FMA.
  loop_scalar:
      TESTQ CX, CX
      JE reduce4
      VMOVSS (DI), X1
      VMOVSS (SI), X2
      VFMADD231SS X1, X2, X0 // X0[0] += X1[0] * X2[0]
      ADDQ $4, DI
      ADDQ $4, SI
      DECQ CX
      JMP loop_scalar

      // Horizontal sum of the 4 lanes of X0 into lane 0.
  reduce4:
      VMOVHLPS X0, X0, X1     // X1[0..1] = X0[2..3]
      VADDPS X1, X0, X0       // X0[0] = x0+x2, X0[1] = x1+x3
      VPSHUFD $0xB1, X0, X1   // swap adjacent lanes: X1[0] = X0[1]
      VADDPS X1, X0, X0       // X0[0] = x0+x1+x2+x3

      // Common exit: store lane 0 and clear the upper YMM state.
  dot_store:
      VMOVSS X0, ret+24(FP)
      VZEROUPPER
      RET
  // func axpyAVX2(alpha float32, x *float32, y *float32, n int)
  //
  // Computes y[i] += alpha * x[i] for i in [0, n), in place. Does nothing
  // when n <= 0.
  //
  // Argument frame: alpha+0 (4 bytes, padded to 8), x+8, y+16, n+24
  // => 32 bytes total.
  //
  // Register roles:
  //   DI = cursor into x        SI = cursor into y
  //   CX = elements remaining   Y0 = alpha broadcast to all 8 lanes
  //   Y1/X1 = x values          Y2/X2 = y values / result
  //
  // All loads and stores are unaligned (VMOVUPS/VMOVSS). The function uses
  // FMA instructions (VFMADD231PS/SS) and does no CPU feature check itself.
  //
  // Only VEX-encoded instructions are used, and VZEROUPPER is executed
  // before returning.
  TEXT ·axpyAVX2(SB), NOSPLIT, $0-32
      MOVQ n+24(FP), CX
      TESTQ CX, CX
      JLE axpy_done

      MOVQ x+8(FP), DI
      MOVQ y+16(FP), SI
      VMOVSS alpha+0(FP), X0
      VBROADCASTSS X0, Y0     // Y0 = {alpha x 8}

      // Main loop: 8 floats per iteration.
  axpy_loop8:
      CMPQ CX, $8
      JL axpy_loop1
      VMOVUPS (DI), Y1
      VMOVUPS (SI), Y2
      VFMADD231PS Y0, Y1, Y2  // Y2 += Y1 * Y0  (y += x * alpha)
      VMOVUPS Y2, (SI)
      ADDQ $32, DI
      ADDQ $32, SI
      SUBQ $8, CX
      JMP axpy_loop8

      // Scalar tail: here CX < 8. Lane 0 of X0 still holds alpha.
  axpy_loop1:
      TESTQ CX, CX
      JE axpy_done
      VMOVSS (DI), X1
      VMOVSS (SI), X2
      VFMADD231SS X0, X1, X2  // X2[0] += X1[0] * X0[0]
      VMOVSS X2, (SI)
      ADDQ $4, DI
      ADDQ $4, SI
      DECQ CX
      JMP axpy_loop1

      // Single exit: clear upper YMM state (harmless if nothing was dirtied).
  axpy_done:
      VZEROUPPER
      RET
  94. RET