gemv_f32_tile8_avx2.s

//go:build amd64
// +build amd64

#include "textflag.h"

// func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)
// Computes the main (multiple-of-8) part of 8 independent dot products:
// out[t] = sum_{i=0..(K&^7)-1} a[i] * b[t*K+i], for t = 0..7
// The K%8 tail elements, if any, are left for the caller to accumulate.
// Vectorizes over K with AVX2/FMA and reuses each A vector across all 8 outputs.
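//
// A scalar Go reference for what this routine stores (a sketch for
// documentation only; the name gemvF32Tile8Ref is illustrative, not part
// of the package):
//
//	func gemvF32Tile8Ref(a, b, out []float32, K int) {
//		kMain := K &^ 7 // only these elements are handled here
//		for t := 0; t < 8; t++ {
//			var sum float32
//			for i := 0; i < kMain; i++ {
//				sum += a[i] * b[t*K+i]
//			}
//			out[t] = sum // overwrites; tail terms must be added afterwards
//		}
//	}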
TEXT ·gemvF32Tile8AVX2(SB), NOSPLIT, $96-32
	// Preserve general-purpose registers (Go ABI + ABI wrappers).
	MOVQ AX, 0(SP)
	MOVQ BX, 8(SP)
	MOVQ CX, 16(SP)
	MOVQ DX, 24(SP)
	MOVQ DI, 32(SP)
	MOVQ SI, 40(SP)
	MOVQ R8, 48(SP)
	MOVQ R9, 56(SP)
	MOVQ R10, 64(SP)
	MOVQ R11, 72(SP)
	MOVQ R12, 80(SP)
	MOVQ R13, 88(SP)
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ out+16(FP), DX
	MOVQ K+24(FP), CX
	// strideBytes = K * 4
	MOVQ CX, BX
	SHLQ $2, BX
	// kMain = K &^ 7 (multiple of 8 floats)
	ANDQ $-8, CX
	JLE zero
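	// e.g. K = 13: CX becomes 8 (one full 8-float block) and the
	// remaining 5 elements are left for the caller's tail pass.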
	// Row pointers b1..b7: each is the previous row plus strideBytes
	// (SI already points at row 0).
	MOVQ SI, R8
	ADDQ BX, R8
	MOVQ R8, R9
	ADDQ BX, R9
	MOVQ R9, R10
	ADDQ BX, R10
	MOVQ R10, R11
	ADDQ BX, R11
	MOVQ R11, R12
	ADDQ BX, R12
	MOVQ R12, R13
	ADDQ BX, R13
	MOVQ R13, AX
	ADDQ BX, AX
	// zero accumulators Y0..Y7
	VXORPS Y0, Y0, Y0
	VXORPS Y1, Y1, Y1
	VXORPS Y2, Y2, Y2
	VXORPS Y3, Y3, Y3
	VXORPS Y4, Y4, Y4
	VXORPS Y5, Y5, Y5
	VXORPS Y6, Y6, Y6
	VXORPS Y7, Y7, Y7
loop:
	// load 8 floats from a
	VMOVUPS (DI), Y8
	// out0..out7 accumulate with shared A vector
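	// Note: Go asm reverses Intel operand order, so
	// VFMADD231PS Y8, Y9, Y0 computes Y0 += Y9 * Y8.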
	VMOVUPS (SI), Y9
	VFMADD231PS Y8, Y9, Y0
	VMOVUPS (R8), Y9
	VFMADD231PS Y8, Y9, Y1
	VMOVUPS (R9), Y9
	VFMADD231PS Y8, Y9, Y2
	VMOVUPS (R10), Y9
	VFMADD231PS Y8, Y9, Y3
	VMOVUPS (R11), Y9
	VFMADD231PS Y8, Y9, Y4
	VMOVUPS (R12), Y9
	VFMADD231PS Y8, Y9, Y5
	VMOVUPS (R13), Y9
	VFMADD231PS Y8, Y9, Y6
	VMOVUPS (AX), Y9
	VFMADD231PS Y8, Y9, Y7
	// advance pointers
	ADDQ $32, DI
	ADDQ $32, SI
	ADDQ $32, R8
	ADDQ $32, R9
	ADDQ $32, R10
	ADDQ $32, R11
	ADDQ $32, R12
	ADDQ $32, R13
	ADDQ $32, AX
	SUBQ $8, CX
	JNZ loop
	// Reduce each accumulator to scalar and store.
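	// Each reduction below folds one YMM register down to a single float:
	//   1. VEXTRACTF128 + VADDPS adds the upper 128 bits onto the lower 4 lanes.
	//   2. VMOVHLPS + VADDPS folds lanes 2..3 onto lanes 0..1.
	//   3. VPSHUFD $0xB1 swaps adjacent lanes, so the final VADDPS leaves
	//      the full sum in lane 0 for MOVSS to store.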
	// Y0 -> out[0]
	VEXTRACTF128 $1, Y0, X8
	VADDPS X8, X0, X0
	VMOVHLPS X0, X0, X8
	VADDPS X8, X0, X0
	VPSHUFD $0xB1, X0, X8
	VADDPS X8, X0, X0
	MOVSS X0, 0(DX)
	// Y1 -> out[1]
	VEXTRACTF128 $1, Y1, X8
	VADDPS X8, X1, X1
	VMOVHLPS X1, X1, X8
	VADDPS X8, X1, X1
	VPSHUFD $0xB1, X1, X8
	VADDPS X8, X1, X1
	MOVSS X1, 4(DX)
	// Y2 -> out[2]
	VEXTRACTF128 $1, Y2, X8
	VADDPS X8, X2, X2
	VMOVHLPS X2, X2, X8
	VADDPS X8, X2, X2
	VPSHUFD $0xB1, X2, X8
	VADDPS X8, X2, X2
	MOVSS X2, 8(DX)
	// Y3 -> out[3]
	VEXTRACTF128 $1, Y3, X8
	VADDPS X8, X3, X3
	VMOVHLPS X3, X3, X8
	VADDPS X8, X3, X3
	VPSHUFD $0xB1, X3, X8
	VADDPS X8, X3, X3
	MOVSS X3, 12(DX)
	// Y4 -> out[4]
	VEXTRACTF128 $1, Y4, X8
	VADDPS X8, X4, X4
	VMOVHLPS X4, X4, X8
	VADDPS X8, X4, X4
	VPSHUFD $0xB1, X4, X8
	VADDPS X8, X4, X4
	MOVSS X4, 16(DX)
	// Y5 -> out[5]
	VEXTRACTF128 $1, Y5, X8
	VADDPS X8, X5, X5
	VMOVHLPS X5, X5, X8
	VADDPS X8, X5, X5
	VPSHUFD $0xB1, X5, X8
	VADDPS X8, X5, X5
	MOVSS X5, 20(DX)
	// Y6 -> out[6]
	VEXTRACTF128 $1, Y6, X8
	VADDPS X8, X6, X6
	VMOVHLPS X6, X6, X8
	VADDPS X8, X6, X6
	VPSHUFD $0xB1, X6, X8
	VADDPS X8, X6, X6
	MOVSS X6, 24(DX)
	// Y7 -> out[7]
	VEXTRACTF128 $1, Y7, X8
	VADDPS X8, X7, X7
	VMOVHLPS X7, X7, X8
	VADDPS X8, X7, X7
	VPSHUFD $0xB1, X7, X8
	VADDPS X8, X7, X7
	MOVSS X7, 28(DX)
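	// Clear the upper halves of the YMM registers to avoid AVX/SSE
	// transition penalties in subsequent SSE code.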
	VZEROUPPER
	JMP epilogue
zero:
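	// K < 8: no full 8-float block. Store zeros so a caller's tail pass
	// can accumulate the K elements directly into out.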
	VXORPS X0, X0, X0
	MOVSS X0, 0(DX)
	MOVSS X0, 4(DX)
	MOVSS X0, 8(DX)
	MOVSS X0, 12(DX)
	MOVSS X0, 16(DX)
	MOVSS X0, 20(DX)
	MOVSS X0, 24(DX)
	MOVSS X0, 28(DX)
	VZEROUPPER
	// fall through to epilogue
epilogue:
	MOVQ 0(SP), AX
	MOVQ 8(SP), BX
	MOVQ 16(SP), CX
	MOVQ 24(SP), DX
	MOVQ 32(SP), DI
	MOVQ 40(SP), SI
	MOVQ 48(SP), R8
	MOVQ 56(SP), R9
	MOVQ 64(SP), R10
	MOVQ 72(SP), R11
	MOVQ 80(SP), R12
	MOVQ 88(SP), R13
	RET
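
// Example caller with tail handling (a hypothetical wrapper, assuming the Go
// declaration `func gemvF32Tile8AVX2(a, b, out *float32, K int)` from the
// header above, with K >= 1, len(a) >= K, len(b) >= 8*K, len(out) >= 8):
//
//	func gemvTile8(a, b, out []float32, K int) {
//		gemvF32Tile8AVX2(&a[0], &b[0], &out[0], K)
//		// Accumulate the K%8 tail that the assembly leaves to the caller.
//		for i := K &^ 7; i < K; i++ {
//			for t := 0; t < 8; t++ {
//				out[t] += a[i] * b[t*K+i]
//			}
//		}
//	}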