// simd_avx512.s: AVX-512 float32 kernels (dot product, axpy) for amd64.

  1. //go:build amd64
  2. // +build amd64
  3. #include "textflag.h"
  // func dotAVX512(a *float32, b *float32, n int) float32
  //
  // dotAVX512 returns the float32 dot product sum(a[i]*b[i]) for 0 <= i < n,
  // accumulated with fused multiply-add. It returns +0.0 when n <= 0, and on
  // that path it does not touch any vector register.
  //
  // CPU features are not checked here; the caller must ensure AVX512F is
  // available. The 512->256 fold uses VEXTRACTF64X4 because it only needs
  // AVX512F (VEXTRACTF32X8 would additionally need AVX512DQ); both extract
  // the same 256 bits. All loads are unaligned (VMOVUPS / VMOVSS).
  //
  // Registers: DI = a cursor, SI = b cursor, CX = elements remaining,
  //            Z0 = 16-lane accumulator (later folded to 4 lanes in X0),
  //            Z1/Y1/X1 and Z2/X2 = scratch.
  //
  // Frame $0-28: no locals. a, b, n take 24 bytes and the float32 result is
  // at offset 24, so the argument area is 28 bytes.
  TEXT ·dotAVX512(SB), NOSPLIT, $0-28
      MOVQ  a+0(FP), DI
      MOVQ  b+8(FP), SI
      MOVQ  n+16(FP), CX

      // Empty or negative length: return +0.0 before any vector state is dirtied.
      TESTQ CX, CX
      JLE   dot512_zero

      VXORPS Z0, Z0, Z0

      // Main loop: 16 float32 (64 bytes) per iteration while CX >= 16.
  dot512_loop16:
      CMPQ        CX, $16
      JL          dot512_fold
      VMOVUPS     (DI), Z1
      VMOVUPS     (SI), Z2
      VFMADD231PS Z1, Z2, Z0          // Z0 += Z1 * Z2
      ADDQ        $64, DI
      ADDQ        $64, SI
      SUBQ        $16, CX
      JMP         dot512_loop16

      // Fold the 16-lane accumulator down to 4 lanes in X0.
  dot512_fold:
      VEXTRACTF64X4 $1, Z0, Y1        // Y1 = upper 8 lanes of Z0 (AVX512F only)
      VADDPS        Y1, Y0, Y0        // Y0 = lower 8 + upper 8
      VEXTRACTF128  $1, Y0, X1
      VADDPS        X1, X0, X0        // X0 = 4 partial sums

      // Vector tail: 4 float32 (16 bytes) per iteration while CX >= 4.
  dot512_loop4:
      CMPQ        CX, $4
      JL          dot512_loop1
      VMOVUPS     (DI), X1
      VMOVUPS     (SI), X2
      VFMADD231PS X1, X2, X0          // X0 += X1 * X2
      ADDQ        $16, DI
      ADDQ        $16, SI
      SUBQ        $4, CX
      JMP         dot512_loop4

      // Scalar tail: the remaining 0..3 elements accumulate into lane 0 of X0.
      // VFMADD231SS leaves lanes 1..3 of X0 unchanged.
  dot512_loop1:
      TESTQ       CX, CX
      JE          dot512_reduce
      VMOVSS      (DI), X1            // VEX form: no legacy-SSE transition
      VMOVSS      (SI), X2
      VFMADD231SS X1, X2, X0          // X0[0] += X1[0] * X2[0]
      ADDQ        $4, DI
      ADDQ        $4, SI
      DECQ        CX
      JMP         dot512_loop1

      // Horizontal sum of the 4 lanes of X0 into lane 0, then return it.
  dot512_reduce:
      VMOVHLPS X0, X0, X1             // X1 = {x2, x3, x2, x3}
      VADDPS   X1, X0, X0             // X0 = {x0+x2, x1+x3, ...}
      VPSHUFD  $0xB1, X0, X1          // swap adjacent lanes
      VADDPS   X1, X0, X0             // X0[0] = x0+x1+x2+x3
      VMOVSS   X0, ret+24(FP)
      VZEROUPPER                      // clear dirty ZMM/YMM upper state before returning
      RET

  dot512_zero:
      MOVL $0, ret+24(FP)             // float32 +0.0 (all-zero bit pattern)
      RET
  // func axpyAVX512(alpha float32, x *float32, y *float32, n int)
  //
  // axpyAVX512 computes y[i] = alpha*x[i] + y[i] for 0 <= i < n with fused
  // multiply-add, updating y in place (x is only read). It does nothing when
  // n <= 0, and on that path it does not touch any vector register.
  //
  // CPU features are not checked here; the caller must ensure AVX512F is
  // available. All loads and stores are unaligned (VMOVUPS / VMOVSS). Each
  // element gets exactly one FMA whichever loop handles it, so results do not
  // depend on how n splits across the 16-, 4- and 1-wide loops.
  //
  // Registers: Z0 = alpha broadcast to 16 lanes (X0 = its low 4 lanes),
  //            DI = x cursor, SI = y cursor, CX = elements remaining,
  //            Z1/X1 = x chunk, Z2/X2 = y chunk.
  //
  // Frame $0-32: no locals. alpha (float32) is padded to 8 bytes, so x, y, n
  // sit at offsets 8, 16, 24 and the argument area is 32 bytes.
  TEXT ·axpyAVX512(SB), NOSPLIT, $0-32
      MOVQ  x+8(FP), DI
      MOVQ  y+16(FP), SI
      MOVQ  n+24(FP), CX
      TESTQ CX, CX
      JLE   axpy512_ret               // n <= 0: nothing to do, vector state untouched

      VMOVSS       alpha+0(FP), X0
      VBROADCASTSS X0, Z0             // Z0 = {alpha x 16}

      // Main loop: 16 float32 (64 bytes) per iteration while CX >= 16.
  axpy512_loop16:
      CMPQ        CX, $16
      JL          axpy512_loop4
      VMOVUPS     (DI), Z1
      VMOVUPS     (SI), Z2
      VFMADD231PS Z0, Z1, Z2          // Z2 = Z1 * alpha + Z2
      VMOVUPS     Z2, (SI)
      ADDQ        $64, DI
      ADDQ        $64, SI
      SUBQ        $16, CX
      JMP         axpy512_loop16

      // Vector tail: 4 float32 (16 bytes) per iteration while CX >= 4.
  axpy512_loop4:
      CMPQ        CX, $4
      JL          axpy512_loop1
      VMOVUPS     (DI), X1
      VMOVUPS     (SI), X2
      VFMADD231PS X0, X1, X2          // X2 = X1 * alpha + X2
      VMOVUPS     X2, (SI)
      ADDQ        $16, DI
      ADDQ        $16, SI
      SUBQ        $4, CX
      JMP         axpy512_loop4

      // Scalar tail: the remaining 0..3 elements.
  axpy512_loop1:
      TESTQ       CX, CX
      JE          axpy512_done
      VMOVSS      (DI), X1            // VEX form: no legacy-SSE transition
      VMOVSS      (SI), X2
      VFMADD231SS X0, X1, X2          // X2[0] = X1[0] * alpha + X2[0]
      VMOVSS      X2, (SI)
      ADDQ        $4, DI
      ADDQ        $4, SI
      DECQ        CX
      JMP         axpy512_loop1

  axpy512_done:
      VZEROUPPER                      // clear dirty ZMM/YMM upper state before returning
  axpy512_ret:
      RET