gemv_f32_tile8_avx512.s

//go:build amd64
// +build amd64

#include "textflag.h"
// func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
//
// Computes 8 independent dot products over the first kMain = K &^ 15 elements:
//
//	out[t] = sum_{i=0..kMain-1} a[i] * b[t*K+i], for t = 0..7
//
// Vectorizes over K with AVX-512 FMA (requires AVX-512 F and DQ) and reuses
// each 16-float vector of a across all 8 outputs. The K mod 16 tail is NOT
// processed here; callers must pad K to a multiple of 16 or finish the
// remainder in Go (see the masked-tail sketch at the end of this file).
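//
// A matching Go-side declaration (an assumed companion .go file, not shown
// in this listing) would be a sketch along these lines:
//
//	//go:noescape
//	func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
//
// with a and b typically passed as &a[0] and &b[0] of K- and 8*K-element
// slices, and out as &out[0] of an 8-element slice.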
TEXT ·gemvF32Tile8AVX512(SB), NOSPLIT, $96-32
	// Spill the general-purpose registers we use. Go's ABI0 treats these
	// as caller-saved, so this is defensive rather than strictly required.
	MOVQ AX, 0(SP)
	MOVQ BX, 8(SP)
	MOVQ CX, 16(SP)
	MOVQ DX, 24(SP)
	MOVQ DI, 32(SP)
	MOVQ SI, 40(SP)
	MOVQ R8, 48(SP)
	MOVQ R9, 56(SP)
	MOVQ R10, 64(SP)
	MOVQ R11, 72(SP)
	MOVQ R12, 80(SP)
	MOVQ R13, 88(SP)
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ out+16(FP), DX
	MOVQ K+24(FP), CX
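	// DI = &a[0], SI = &b[0] (row 0 of b), DX = &out[0], CX = K (elements).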
	// strideBytes = K * 4 (bytes per row of b)
	MOVQ CX, BX
	SHLQ $2, BX
	// kMain = K &^ 15: largest multiple of 16 not exceeding K. ANDQ sets
	// the flags, so JLE also catches K < 16 and K <= 0.
	ANDQ $-16, CX
	JLE zero
	// Row pointers b1..b7: row t of b starts at b + t*strideBytes.
	MOVQ SI, R8
	ADDQ BX, R8
	MOVQ R8, R9
	ADDQ BX, R9
	MOVQ R9, R10
	ADDQ BX, R10
	MOVQ R10, R11
	ADDQ BX, R11
	MOVQ R11, R12
	ADDQ BX, R12
	MOVQ R12, R13
	ADDQ BX, R13
	MOVQ R13, AX
	ADDQ BX, AX
	// Zero accumulators Z0..Z7.
	VXORPS Z0, Z0, Z0
	VXORPS Z1, Z1, Z1
	VXORPS Z2, Z2, Z2
	VXORPS Z3, Z3, Z3
	VXORPS Z4, Z4, Z4
	VXORPS Z5, Z5, Z5
	VXORPS Z6, Z6, Z6
	VXORPS Z7, Z7, Z7
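	// Main loop: each iteration loads 16 floats of a into Z8 once, then
	// issues one FMA per row of b (reloading Z9 each time), so that
	// Z_t += a_vec * b_t_vec for t = 0..7.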
loop:
	VMOVUPS (DI), Z8
	VMOVUPS (SI), Z9
	VFMADD231PS Z8, Z9, Z0
	VMOVUPS (R8), Z9
	VFMADD231PS Z8, Z9, Z1
	VMOVUPS (R9), Z9
	VFMADD231PS Z8, Z9, Z2
	VMOVUPS (R10), Z9
	VFMADD231PS Z8, Z9, Z3
	VMOVUPS (R11), Z9
	VFMADD231PS Z8, Z9, Z4
	VMOVUPS (R12), Z9
	VFMADD231PS Z8, Z9, Z5
	VMOVUPS (R13), Z9
	VFMADD231PS Z8, Z9, Z6
	VMOVUPS (AX), Z9
	VFMADD231PS Z8, Z9, Z7
	ADDQ $64, DI
	ADDQ $64, SI
	ADDQ $64, R8
	ADDQ $64, R9
	ADDQ $64, R10
	ADDQ $64, R11
	ADDQ $64, R12
	ADDQ $64, R13
	ADDQ $64, AX
	SUBQ $16, CX
	JNZ loop
	// Reduce each accumulator to a scalar and store.
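	// Pattern per register: fold 512->256->128 bits with extract+add, then
	// VPSHUFD $0x4E (swap 64-bit halves) and $0xB1 (swap adjacent 32-bit
	// lanes) finish the horizontal sum, leaving the result in lane 0.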
	// Z0 -> out[0]
	VEXTRACTF32X8 $1, Z0, Y8
	VADDPS Y8, Y0, Y0
	VEXTRACTF128 $1, Y0, X8
	VADDPS X8, X0, X0
	VPSHUFD $0x4E, X0, X8
	VADDPS X8, X0, X0
	VPSHUFD $0xB1, X0, X8
	VADDPS X8, X0, X0
	MOVSS X0, 0(DX)
	// Z1 -> out[1]
	VEXTRACTF32X8 $1, Z1, Y8
	VADDPS Y8, Y1, Y1
	VEXTRACTF128 $1, Y1, X8
	VADDPS X8, X1, X1
	VPSHUFD $0x4E, X1, X8
	VADDPS X8, X1, X1
	VPSHUFD $0xB1, X1, X8
	VADDPS X8, X1, X1
	MOVSS X1, 4(DX)
	// Z2 -> out[2]
	VEXTRACTF32X8 $1, Z2, Y8
	VADDPS Y8, Y2, Y2
	VEXTRACTF128 $1, Y2, X8
	VADDPS X8, X2, X2
	VPSHUFD $0x4E, X2, X8
	VADDPS X8, X2, X2
	VPSHUFD $0xB1, X2, X8
	VADDPS X8, X2, X2
	MOVSS X2, 8(DX)
	// Z3 -> out[3]
	VEXTRACTF32X8 $1, Z3, Y8
	VADDPS Y8, Y3, Y3
	VEXTRACTF128 $1, Y3, X8
	VADDPS X8, X3, X3
	VPSHUFD $0x4E, X3, X8
	VADDPS X8, X3, X3
	VPSHUFD $0xB1, X3, X8
	VADDPS X8, X3, X3
	MOVSS X3, 12(DX)
	// Z4 -> out[4]
	VEXTRACTF32X8 $1, Z4, Y8
	VADDPS Y8, Y4, Y4
	VEXTRACTF128 $1, Y4, X8
	VADDPS X8, X4, X4
	VPSHUFD $0x4E, X4, X8
	VADDPS X8, X4, X4
	VPSHUFD $0xB1, X4, X8
	VADDPS X8, X4, X4
	MOVSS X4, 16(DX)
	// Z5 -> out[5]
	VEXTRACTF32X8 $1, Z5, Y8
	VADDPS Y8, Y5, Y5
	VEXTRACTF128 $1, Y5, X8
	VADDPS X8, X5, X5
	VPSHUFD $0x4E, X5, X8
	VADDPS X8, X5, X5
	VPSHUFD $0xB1, X5, X8
	VADDPS X8, X5, X5
	MOVSS X5, 20(DX)
	// Z6 -> out[6]
	VEXTRACTF32X8 $1, Z6, Y8
	VADDPS Y8, Y6, Y6
	VEXTRACTF128 $1, Y6, X8
	VADDPS X8, X6, X6
	VPSHUFD $0x4E, X6, X8
	VADDPS X8, X6, X6
	VPSHUFD $0xB1, X6, X8
	VADDPS X8, X6, X6
	MOVSS X6, 24(DX)
	// Z7 -> out[7]
	VEXTRACTF32X8 $1, Z7, Y8
	VADDPS Y8, Y7, Y7
	VEXTRACTF128 $1, Y7, X8
	VADDPS X8, X7, X7
	VPSHUFD $0x4E, X7, X8
	VADDPS X8, X7, X7
	VPSHUFD $0xB1, X7, X8
	VADDPS X8, X7, X7
	MOVSS X7, 28(DX)
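	// Clear the upper YMM/ZMM state before returning to Go code, avoiding
	// AVX-SSE transition penalties in the caller.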
	VZEROUPPER
	JMP epilogue
zero:
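	// kMain <= 0 (K < 16): no full vector iterations, so every output is 0.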
	VXORPS X0, X0, X0
	MOVSS X0, 0(DX)
	MOVSS X0, 4(DX)
	MOVSS X0, 8(DX)
	MOVSS X0, 12(DX)
	MOVSS X0, 16(DX)
	MOVSS X0, 20(DX)
	MOVSS X0, 24(DX)
	MOVSS X0, 28(DX)
	VZEROUPPER
epilogue:
	MOVQ 0(SP), AX
	MOVQ 8(SP), BX
	MOVQ 16(SP), CX
	MOVQ 24(SP), DX
	MOVQ 32(SP), DI
	MOVQ 40(SP), SI
	MOVQ 48(SP), R8
	MOVQ 56(SP), R9
	MOVQ 64(SP), R10
	MOVQ 72(SP), R11
	MOVQ 80(SP), R12
	MOVQ 88(SP), R13
	RET
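
// Note on the dropped K mod 16 tail: one way to extend this kernel is a
// masked pass between the main loop and the reduction. The following is a
// sketch only, assuming the Go assembler's AVX-512 opmask syntax (K
// registers, .Z zeroing suffix); the noTail label is hypothetical and not
// part of this file:
//
//	MOVQ K+24(FP), CX
//	ANDQ $15, CX              // r = K mod 16 (remaining lanes)
//	JZ   noTail
//	MOVL $1, BX
//	SHLL CX, BX               // variable shift count must live in CX
//	DECL BX                   // BX = (1 << r) - 1: mask of the low r lanes
//	KMOVW BX, K1
//	VMOVUPS.Z (DI), K1, Z8    // masked load: r floats of a, rest zeroed
//	VMOVUPS.Z (SI), K1, Z9    // same for row 0 of b
//	VFMADD231PS Z8, Z9, Z0
//	// ...repeat the masked load + FMA for rows 1..7 (R8..R13, AX)...
//	noTail:
//
// Alternatively, the caller can finish the remainder with a scalar Go loop.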