simd_dequant_avx512.s 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. //go:build amd64
  2. // +build amd64
  3. #include "textflag.h"
  4. // ============================================================================
  5. // Q8_K Dequantization - AVX512
  6. // BlockQ8_K layout:
  7. // offset 0: D (float32)
  8. // offset 4: QS[256] (int8)
  9. // ============================================================================
  10. // func dequantQ8KAVX512(b *BlockQ8_K, out *float32)
  11. TEXT ·dequantQ8KAVX512(SB), NOSPLIT, $0-16
  12. MOVQ b+0(FP), DI
  13. MOVQ out+8(FP), SI
  14. // Broadcast scale d to Z0
  15. VBROADCASTSS (DI), Z0
  16. // QS pointer = b + 4
  17. LEAQ 4(DI), R8
  18. MOVQ SI, R9
  19. // Process 256 elements, 16 at a time (unrolled by 4)
  20. MOVQ $0, CX
  21. loop_q8:
  22. CMPQ CX, $256
  23. JGE done_q8
  24. // Load 16 int8, sign-extend to 16 int32, convert to float, multiply
  25. VPMOVSXBD (R8), Z1
  26. VCVTDQ2PS Z1, Z1
  27. VMULPS Z0, Z1, Z1
  28. VMOVUPS Z1, (R9)
  29. VPMOVSXBD 16(R8), Z2
  30. VCVTDQ2PS Z2, Z2
  31. VMULPS Z0, Z2, Z2
  32. VMOVUPS Z2, 64(R9)
  33. VPMOVSXBD 32(R8), Z3
  34. VCVTDQ2PS Z3, Z3
  35. VMULPS Z0, Z3, Z3
  36. VMOVUPS Z3, 128(R9)
  37. VPMOVSXBD 48(R8), Z4
  38. VCVTDQ2PS Z4, Z4
  39. VMULPS Z0, Z4, Z4
  40. VMOVUPS Z4, 192(R9)
  41. ADDQ $64, R8
  42. ADDQ $256, R9
  43. ADDQ $64, CX
  44. JMP loop_q8
  45. done_q8:
  46. VZEROUPPER
  47. RET
  48. // ==========================================================================
  49. // Q8_K Fused Dot - AVX512
  50. // Computes: sum_i x[i] * (d * float32(qs[i]))
  51. // func dotQ8KAVX512(b *BlockQ8_K, x *float32) float32
  52. TEXT ·dotQ8KAVX512(SB), NOSPLIT, $0-24
  53. MOVQ b+0(FP), DI
  54. MOVQ x+8(FP), SI
  55. // Load scale d
  56. MOVSS (DI), X0
  57. // QS pointer = b + 4
  58. LEAQ 4(DI), R8
  59. // Accumulator Z1
  60. VXORPS Z1, Z1, Z1
  61. MOVQ $0, CX
  62. dot_q8_512_loop:
  63. CMPQ CX, $256
  64. JGE dot_q8_512_reduce
  65. // 16x int8 -> 16x int32 -> 16x float
  66. VPMOVSXBD (R8), Z2
  67. VCVTDQ2PS Z2, Z2
  68. // load 16 floats from x
  69. VMOVUPS (SI), Z3
  70. VFMADD231PS Z3, Z2, Z1
  71. ADDQ $16, R8
  72. ADDQ $64, SI
  73. ADDQ $16, CX
  74. JMP dot_q8_512_loop
  75. dot_q8_512_reduce:
  76. // horizontal reduce Z1 -> X1
  77. VEXTRACTF32X8 $1, Z1, Y2
  78. VADDPS Y2, Y1, Y1
  79. VEXTRACTF128 $1, Y1, X2
  80. VADDPS X2, X1, X1
  81. VPSHUFD $0x4E, X1, X2
  82. VADDPS X2, X1, X1
  83. VPSHUFD $0xB1, X1, X2
  84. VADDPS X2, X1, X1
  85. // multiply by d
  86. MULSS X0, X1
  87. MOVSS X1, ret+16(FP)
  88. VZEROUPPER
  89. RET
  90. TEXT ·dotQ4KInnerAVX512(SB), NOSPLIT, $0-40
  91. MOVQ qs+0(FP), DI
  92. MOVQ x+8(FP), SI
  93. LEAQ 128(SI), R10
  94. VBROADCASTSS d1+16(FP), Z0
  95. VBROADCASTSS m1+20(FP), Z1
  96. VBROADCASTSS d2+24(FP), Z2
  97. VBROADCASTSS m2+28(FP), Z3
  98. MOVL $0x0F0F0F0F, AX
  99. MOVD AX, X7
  100. VPBROADCASTD X7, Z7
  101. VXORPS Z12, Z12, Z12
  102. VXORPS Z13, Z13, Z13
  103. VXORPS Z14, Z14, Z14
  104. VXORPS Z15, Z15, Z15
  105. MOVQ $0, CX
  106. dot_q4k_512_loop:
  107. CMPQ CX, $32
  108. JGE dot_q4k_512_reduce
  109. VPMOVZXBD (DI), Z8
  110. VPANDD Z7, Z8, Z9
  111. VCVTDQ2PS Z9, Z9
  112. VMOVUPS (SI), Z10
  113. VADDPS Z10, Z13, Z13
  114. VFMADD231PS Z10, Z9, Z12
  115. VPSRLD $4, Z8, Z8
  116. VCVTDQ2PS Z8, Z8
  117. VMOVUPS (R10), Z11
  118. VADDPS Z11, Z15, Z15
  119. VFMADD231PS Z11, Z8, Z14
  120. ADDQ $16, DI
  121. ADDQ $64, SI
  122. ADDQ $64, R10
  123. ADDQ $16, CX
  124. JMP dot_q4k_512_loop
  125. dot_q4k_512_reduce:
  126. VMULPS Z0, Z12, Z12
  127. VMULPS Z1, Z13, Z13
  128. VSUBPS Z13, Z12, Z12
  129. VMULPS Z2, Z14, Z14
  130. VMULPS Z3, Z15, Z15
  131. VSUBPS Z15, Z14, Z14
  132. VADDPS Z14, Z12, Z12
  133. VEXTRACTF32X8 $1, Z12, Y1
  134. VADDPS Y1, Y12, Y12
  135. VEXTRACTF128 $1, Y12, X1
  136. VADDPS X1, X12, X12
  137. VPSHUFD $0x4E, X12, X1
  138. VADDPS X1, X12, X12
  139. VPSHUFD $0xB1, X12, X1
  140. VADDPS X1, X12, X12
  141. MOVSS X12, ret+32(FP)
  142. VZEROUPPER
  143. RET
  144. // func dotQ5KInnerAVX512(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
  145. TEXT ·dotQ5KInnerAVX512(SB), NOSPLIT, $0-64
  146. MOVQ qs+0(FP), DI
  147. MOVQ qh+8(FP), SI
  148. MOVQ x+16(FP), DX
  149. LEAQ 128(DX), R10
  150. VBROADCASTSS d1+24(FP), Z0
  151. VBROADCASTSS m1+28(FP), Z1
  152. VBROADCASTSS d2+32(FP), Z2
  153. VBROADCASTSS m2+36(FP), Z3
  154. MOVL $0x0F0F0F0F, AX
  155. MOVD AX, X7
  156. VPBROADCASTD X7, Z7
  157. MOVL $1, AX
  158. MOVD AX, X4
  159. VPBROADCASTD X4, Z4
  160. MOVQ u1+40(FP), CX
  161. MOVQ $1, AX
  162. SHLQ CL, AX
  163. MOVL AX, BX
  164. MOVD BX, X6
  165. VPBROADCASTD X6, Z6
  166. MOVQ u2+48(FP), CX
  167. MOVQ $1, AX
  168. SHLQ CL, AX
  169. MOVL AX, BX
  170. MOVD BX, X5
  171. VPBROADCASTD X5, Z5
  172. VXORPS Z15, Z15, Z15
  173. MOVQ u1+40(FP), AX
  174. CMPQ AX, $0
  175. JE dot_q5k_loop_s0
  176. CMPQ AX, $2
  177. JE dot_q5k_loop_s1
  178. CMPQ AX, $4
  179. JE dot_q5k_loop_s2
  180. JMP dot_q5k_loop_s3
  181. dot_q5k_loop_s0:
  182. MOVQ $0, CX
  183. dot_q5k_loop0:
  184. CMPQ CX, $32
  185. JGE dot_q5k_reduce
  186. VPMOVZXBD (DI), Z11
  187. VPANDD Z7, Z11, Z9
  188. VPSRLD $4, Z11, Z10
  189. VPMOVZXBD (SI), Z12
  190. VPANDD Z6, Z12, Z13
  191. VPSRLD $0, Z13, Z13
  192. VPANDD Z4, Z13, Z13
  193. VPSLLD $4, Z13, Z13
  194. VPANDD Z5, Z12, Z8
  195. VPSRLD $1, Z8, Z8
  196. VPANDD Z4, Z8, Z8
  197. VPSLLD $4, Z8, Z8
  198. VPADDD Z13, Z9, Z9
  199. VPADDD Z8, Z10, Z10
  200. VCVTDQ2PS Z9, Z9
  201. VMULPS Z0, Z9, Z9
  202. VSUBPS Z1, Z9, Z9
  203. VMOVUPS (DX), Z14
  204. VFMADD231PS Z14, Z9, Z15
  205. VCVTDQ2PS Z10, Z10
  206. VMULPS Z2, Z10, Z10
  207. VSUBPS Z3, Z10, Z10
  208. VMOVUPS (R10), Z14
  209. VFMADD231PS Z14, Z10, Z15
  210. ADDQ $16, DI
  211. ADDQ $16, SI
  212. ADDQ $64, DX
  213. ADDQ $64, R10
  214. ADDQ $16, CX
  215. JMP dot_q5k_loop0
  216. dot_q5k_loop_s1:
  217. MOVQ $0, CX
  218. dot_q5k_loop1:
  219. CMPQ CX, $32
  220. JGE dot_q5k_reduce
  221. VPMOVZXBD (DI), Z11
  222. VPANDD Z7, Z11, Z9
  223. VPSRLD $4, Z11, Z10
  224. VPMOVZXBD (SI), Z12
  225. VPANDD Z6, Z12, Z13
  226. VPSRLD $2, Z13, Z13
  227. VPANDD Z4, Z13, Z13
  228. VPSLLD $4, Z13, Z13
  229. VPANDD Z5, Z12, Z8
  230. VPSRLD $3, Z8, Z8
  231. VPANDD Z4, Z8, Z8
  232. VPSLLD $4, Z8, Z8
  233. VPADDD Z13, Z9, Z9
  234. VPADDD Z8, Z10, Z10
  235. VCVTDQ2PS Z9, Z9
  236. VMULPS Z0, Z9, Z9
  237. VSUBPS Z1, Z9, Z9
  238. VMOVUPS (DX), Z14
  239. VFMADD231PS Z14, Z9, Z15
  240. VCVTDQ2PS Z10, Z10
  241. VMULPS Z2, Z10, Z10
  242. VSUBPS Z3, Z10, Z10
  243. VMOVUPS (R10), Z14
  244. VFMADD231PS Z14, Z10, Z15
  245. ADDQ $16, DI
  246. ADDQ $16, SI
  247. ADDQ $64, DX
  248. ADDQ $64, R10
  249. ADDQ $16, CX
  250. JMP dot_q5k_loop1
  251. dot_q5k_loop_s2:
  252. MOVQ $0, CX
  253. dot_q5k_loop2:
  254. CMPQ CX, $32
  255. JGE dot_q5k_reduce
  256. VPMOVZXBD (DI), Z11
  257. VPANDD Z7, Z11, Z9
  258. VPSRLD $4, Z11, Z10
  259. VPMOVZXBD (SI), Z12
  260. VPANDD Z6, Z12, Z13
  261. VPSRLD $4, Z13, Z13
  262. VPANDD Z4, Z13, Z13
  263. VPSLLD $4, Z13, Z13
  264. VPANDD Z5, Z12, Z8
  265. VPSRLD $5, Z8, Z8
  266. VPANDD Z4, Z8, Z8
  267. VPSLLD $4, Z8, Z8
  268. VPADDD Z13, Z9, Z9
  269. VPADDD Z8, Z10, Z10
  270. VCVTDQ2PS Z9, Z9
  271. VMULPS Z0, Z9, Z9
  272. VSUBPS Z1, Z9, Z9
  273. VMOVUPS (DX), Z14
  274. VFMADD231PS Z14, Z9, Z15
  275. VCVTDQ2PS Z10, Z10
  276. VMULPS Z2, Z10, Z10
  277. VSUBPS Z3, Z10, Z10
  278. VMOVUPS (R10), Z14
  279. VFMADD231PS Z14, Z10, Z15
  280. ADDQ $16, DI
  281. ADDQ $16, SI
  282. ADDQ $64, DX
  283. ADDQ $64, R10
  284. ADDQ $16, CX
  285. JMP dot_q5k_loop2
  286. dot_q5k_loop_s3:
  287. MOVQ $0, CX
  288. dot_q5k_loop3:
  289. CMPQ CX, $32
  290. JGE dot_q5k_reduce
  291. VPMOVZXBD (DI), Z11
  292. VPANDD Z7, Z11, Z9
  293. VPSRLD $4, Z11, Z10
  294. VPMOVZXBD (SI), Z12
  295. VPANDD Z6, Z12, Z13
  296. VPSRLD $6, Z13, Z13
  297. VPANDD Z4, Z13, Z13
  298. VPSLLD $4, Z13, Z13
  299. VPANDD Z5, Z12, Z8
  300. VPSRLD $7, Z8, Z8
  301. VPANDD Z4, Z8, Z8
  302. VPSLLD $4, Z8, Z8
  303. VPADDD Z13, Z9, Z9
  304. VPADDD Z8, Z10, Z10
  305. VCVTDQ2PS Z9, Z9
  306. VMULPS Z0, Z9, Z9
  307. VSUBPS Z1, Z9, Z9
  308. VMOVUPS (DX), Z14
  309. VFMADD231PS Z14, Z9, Z15
  310. VCVTDQ2PS Z10, Z10
  311. VMULPS Z2, Z10, Z10
  312. VSUBPS Z3, Z10, Z10
  313. VMOVUPS (R10), Z14
  314. VFMADD231PS Z14, Z10, Z15
  315. ADDQ $16, DI
  316. ADDQ $16, SI
  317. ADDQ $64, DX
  318. ADDQ $64, R10
  319. ADDQ $16, CX
  320. JMP dot_q5k_loop3
  321. dot_q5k_reduce:
  322. VEXTRACTF32X8 $1, Z15, Y1
  323. VADDPS Y1, Y15, Y15
  324. VEXTRACTF128 $1, Y15, X1
  325. VADDPS X1, X15, X15
  326. VPSHUFD $0x4E, X15, X1
  327. VADDPS X1, X15, X15
  328. VPSHUFD $0xB1, X15, X1
  329. VADDPS X1, X15, X15
  330. MOVSS X15, ret+56(FP)
  331. VZEROUPPER
  332. RET
  333. TEXT ·dotQ6KInnerAVX512(SB), NOSPLIT, $0-40
  334. MOVQ ql+0(FP), DI
  335. MOVQ qh+8(FP), SI
  336. MOVQ scales+16(FP), DX
  337. MOVQ x+24(FP), R8
  338. MOVL $0x0F, AX
  339. MOVD AX, X11
  340. VPBROADCASTD X11, Z11
  341. MOVL $0x03, AX
  342. MOVD AX, X10
  343. VPBROADCASTD X10, Z10
  344. MOVL $0x42000000, AX
  345. MOVD AX, X9
  346. VBROADCASTSS X9, Z9
  347. VXORPS Z15, Z15, Z15
  348. MOVQ $0, CX
  349. dot_q6k_512_loop:
  350. CMPQ CX, $32
  351. JGE dot_q6k_512_reduce
  352. MOVQ CX, R11
  353. SHRQ $4, R11
  354. VPMOVZXBD (DI)(CX*1), Z0
  355. VPMOVZXBD 32(DI)(CX*1), Z1
  356. VPMOVZXBD (SI)(CX*1), Z2
  357. VPANDD Z11, Z0, Z3
  358. VPANDD Z10, Z2, Z4
  359. VPSLLD $4, Z4, Z4
  360. VPADDD Z4, Z3, Z3
  361. VCVTDQ2PS Z3, Z3
  362. VSUBPS Z9, Z3, Z3
  363. VBROADCASTSS (DX)(R11*4), Z4
  364. VMULPS Z4, Z3, Z3
  365. LEAQ (R8)(CX*4), R12
  366. VMOVUPS (R12), Z5
  367. VFMADD231PS Z5, Z3, Z15
  368. VPANDD Z11, Z1, Z3
  369. VPSRLD $2, Z2, Z4
  370. VPANDD Z10, Z4, Z4
  371. VPSLLD $4, Z4, Z4
  372. VPADDD Z4, Z3, Z3
  373. VCVTDQ2PS Z3, Z3
  374. VSUBPS Z9, Z3, Z3
  375. VBROADCASTSS 8(DX)(R11*4), Z4
  376. VMULPS Z4, Z3, Z3
  377. VMOVUPS 128(R12), Z5
  378. VFMADD231PS Z5, Z3, Z15
  379. VPSRLD $4, Z0, Z3
  380. VPANDD Z11, Z3, Z3
  381. VPSRLD $4, Z2, Z4
  382. VPANDD Z10, Z4, Z4
  383. VPSLLD $4, Z4, Z4
  384. VPADDD Z4, Z3, Z3
  385. VCVTDQ2PS Z3, Z3
  386. VSUBPS Z9, Z3, Z3
  387. VBROADCASTSS 16(DX)(R11*4), Z4
  388. VMULPS Z4, Z3, Z3
  389. VMOVUPS 256(R12), Z5
  390. VFMADD231PS Z5, Z3, Z15
  391. VPSRLD $4, Z1, Z3
  392. VPANDD Z11, Z3, Z3
  393. VPSRLD $6, Z2, Z4
  394. VPANDD Z10, Z4, Z4
  395. VPSLLD $4, Z4, Z4
  396. VPADDD Z4, Z3, Z3
  397. VCVTDQ2PS Z3, Z3
  398. VSUBPS Z9, Z3, Z3
  399. VBROADCASTSS 24(DX)(R11*4), Z4
  400. VMULPS Z4, Z3, Z3
  401. VMOVUPS 384(R12), Z5
  402. VFMADD231PS Z5, Z3, Z15
  403. ADDQ $16, CX
  404. JMP dot_q6k_512_loop
  405. dot_q6k_512_reduce:
  406. VEXTRACTF32X8 $1, Z15, Y1
  407. VADDPS Y1, Y15, Y15
  408. VEXTRACTF128 $1, Y15, X1
  409. VADDPS X1, X15, X15
  410. VPSHUFD $0x4E, X15, X1
  411. VADDPS X1, X15, X15
  412. VPSHUFD $0xB1, X15, X1
  413. VADDPS X1, X15, X15
  414. MOVSS X15, ret+32(FP)
  415. VZEROUPPER
  416. RET