//go:build amd64
// +build amd64

#include "textflag.h"

// ============================================================================
// Q8_K Dequantization - AVX2
// BlockQ8_K layout:
//   offset 0: D (float32)
//   offset 4: QS[256] (int8)
// ============================================================================
// func dequantQ8KAVX2(b *BlockQ8_K, out *float32)
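//
// A minimal Go sketch of the scalar semantics this routine vectorizes
// (illustrative only; field names D and QS follow the layout above):
//
//    func dequantQ8KRef(b *BlockQ8_K, out *[256]float32) {
//        for i := 0; i < 256; i++ {
//            out[i] = b.D * float32(b.QS[i]) // scale each int8 quant by D
//        }
//    }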
TEXT ·dequantQ8KAVX2(SB), NOSPLIT, $0-16
    MOVQ b+0(FP), DI
    MOVQ out+8(FP), SI
    // Broadcast scale d to Y0
    VBROADCASTSS (DI), Y0
    // QS pointer = b + 4
    LEAQ 4(DI), R8
    MOVQ SI, R9
    // Process 256 elements, 32 at a time (unrolled)
    MOVQ $0, CX
loop_q8:
    CMPQ CX, $256
    JGE done_q8
    // Load 8 int8, sign-extend to 8 int32, convert to float, multiply by scale
    VPMOVSXBD (R8), Y1
    VCVTDQ2PS Y1, Y1
    VMULPS Y0, Y1, Y1
    VMOVUPS Y1, (R9)
    VPMOVSXBD 8(R8), Y2
    VCVTDQ2PS Y2, Y2
    VMULPS Y0, Y2, Y2
    VMOVUPS Y2, 32(R9)
    VPMOVSXBD 16(R8), Y3
    VCVTDQ2PS Y3, Y3
    VMULPS Y0, Y3, Y3
    VMOVUPS Y3, 64(R9)
    VPMOVSXBD 24(R8), Y4
    VCVTDQ2PS Y4, Y4
    VMULPS Y0, Y4, Y4
    VMOVUPS Y4, 96(R9)
    ADDQ $32, R8
    ADDQ $128, R9
    ADDQ $32, CX
    JMP loop_q8
done_q8:
    VZEROUPPER
    RET
// ============================================================================
// Q3_K Fused Dot Inner Loop - AVX2
// Computes: sum_i x[i] * (dl * qv_i) for 16 elements,
// where qv_i = ((q[i] >> shift) & 3) - 4 if (hm[i] & m) == 0,
// else (q[i] >> shift) & 3.
// ============================================================================
// func dotQ3KInnerAVX2Fused(q *byte, hm *byte, x *float32, dl float32, m uint8, shift uint) float32
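//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices for readability):
//
//    func dotQ3KRef(q, hm []byte, x []float32, dl float32, m uint8, shift uint) float32 {
//        var sum float32
//        for i := 0; i < 16; i++ {
//            v := int32(q[i]>>shift) & 3
//            if hm[i]&m == 0 {
//                v -= 4 // high bit clear: value is in the negative range
//            }
//            sum += x[i] * (dl * float32(v))
//        }
//        return sum
//    }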
TEXT ·dotQ3KInnerAVX2Fused(SB), NOSPLIT, $0-48
    MOVQ q+0(FP), DI
    MOVQ hm+8(FP), SI
    MOVQ x+16(FP), DX
    // Load dl (float32) -> Y0
    VBROADCASTSS dl+24(FP), Y0
    // Load shift -> X2
    MOVQ shift+32(FP), AX
    MOVD AX, X2
    // Load 16 bytes from q -> 16 words in Y3
    VPMOVZXBW (DI), Y3
    // Shift right by variable 'shift'
    VPSRLW X2, Y3, Y3
    // Mask with 3 (0x0003)
    MOVL $3, BX
    MOVD BX, X4
    VPBROADCASTW X4, Y4
    VPAND Y4, Y3, Y3 // Y3 = (q >> shift) & 3
    // Handle HMask: subtract 4 where (hm & m) == 0
    MOVBLZX m+28(FP), BX
    MOVD BX, X5
    VPBROADCASTB X5, X5 // X5 = m repeated
    VMOVDQU (SI), X6
    VPAND X5, X6, X6    // X6 = hm & m
    VPXOR X7, X7, X7    // Zero
    VPCMPEQB X7, X6, X6 // X6 = (hm&m == 0) ? FF : 00
    VPMOVSXBW X6, Y6    // byte mask -> word mask (-1 or 0)
    MOVL $-4, BX
    MOVD BX, X7
    VPBROADCASTW X7, Y7 // Y7 = -4 repeated
    VPAND Y7, Y6, Y6    // Y6 = -4 or 0
    VPADDW Y6, Y3, Y3   // Y3 = val - 4 (if needed)
    // Split into low/high 8 words -> int32
    // Low half of Y3 is already in X3.
    VPMOVSXWD X3, Y8
    VEXTRACTI128 $1, Y3, X4
    VPMOVSXWD X4, Y9
    // Convert to float and scale by dl
    VCVTDQ2PS Y8, Y8
    VCVTDQ2PS Y9, Y9
    VMULPS Y0, Y8, Y8
    VMULPS Y0, Y9, Y9
    // Load x and accumulate dot
    VMOVUPS (DX), Y10
    VMOVUPS 32(DX), Y11
    VMULPS Y10, Y8, Y8
    VMULPS Y11, Y9, Y9
    VADDPS Y9, Y8, Y8
    // Horizontal sum Y8 -> X8
    VEXTRACTF128 $1, Y8, X1
    VADDPS X1, X8, X8
    VHADDPS X8, X8, X8
    VHADDPS X8, X8, X8
    VMOVSS X8, ret+40(FP)
    VZEROUPPER
    RET
// ============================================================================
// Q2_K Fused Dot - AVX2
// Computes: sum_i x[i] * (dl*val_i - ml), for i in [0..15],
// where val_i = (q[i] >> shift) & 3
// ============================================================================
// func dotQ2KInnerAVX2Fused(q *byte, x *float32, dl, ml float32, shift uint) float32
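//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices):
//
//    func dotQ2KRef(q []byte, x []float32, dl, ml float32, shift uint) float32 {
//        var sum float32
//        for i := 0; i < 16; i++ {
//            val := (q[i] >> shift) & 3
//            sum += x[i] * (dl*float32(val) - ml)
//        }
//        return sum
//    }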
TEXT ·dotQ2KInnerAVX2Fused(SB), NOSPLIT, $0-40
    MOVQ q+0(FP), DI
    MOVQ x+8(FP), SI
    VBROADCASTSS dl+16(FP), Y0
    VBROADCASTSS ml+20(FP), Y1
    MOVQ shift+24(FP), CX
    // Mask for 2 bits
    MOVL $0x03030303, AX
    MOVD AX, X7
    VPBROADCASTD X7, Y7
    // Shift amount
    MOVD CX, X6
    // Accumulator
    VXORPS Y15, Y15, Y15
    // Low 8 bytes
    VPMOVZXBD (DI), Y2
    VPSRLD X6, Y2, Y2
    VPAND Y7, Y2, Y2
    VCVTDQ2PS Y2, Y2
    VMULPS Y0, Y2, Y2
    VSUBPS Y1, Y2, Y2
    VMOVUPS (SI), Y4
    VFMADD231PS Y4, Y2, Y15
    // High 8 bytes
    VPMOVZXBD 8(DI), Y3
    VPSRLD X6, Y3, Y3
    VPAND Y7, Y3, Y3
    VCVTDQ2PS Y3, Y3
    VMULPS Y0, Y3, Y3
    VSUBPS Y1, Y3, Y3
    VMOVUPS 32(SI), Y5
    VFMADD231PS Y5, Y3, Y15
    // Reduce Y15 -> scalar
    VEXTRACTF128 $1, Y15, X1
    VADDPS X1, X15, X0
    VHADDPS X0, X0, X0
    VHADDPS X0, X0, X0
    VMOVSS X0, ret+32(FP)
    VZEROUPPER
    RET
// ============================================================================
// Q5_K Fused Dot - AVX2
// Computes: sum_i x[i]*(d1*lo_i - m1) + x[32+i]*(d2*hi_i - m2), for i in [0..31],
// where lo_i = (qs[i] & 0xF) + 16*((qh[i] >> u1) & 1)
// and   hi_i = (qs[i] >> 4)  + 16*((qh[i] >> u2) & 1).
// ============================================================================
// func dotQ5KInnerAVX2(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
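//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices, and u1/u2 are bit indices below 8 in practice):
//
//    func dotQ5KRef(qs, qh []byte, x []float32, d1, m1, d2, m2 float32, u1, u2 uint) float32 {
//        var sum float32
//        for i := 0; i < 32; i++ {
//            lo, hi := int32(qs[i]&0xF), int32(qs[i]>>4)
//            if qh[i]&(1<<u1) != 0 {
//                lo += 16 // fifth bit from qh
//            }
//            if qh[i]&(1<<u2) != 0 {
//                hi += 16
//            }
//            sum += x[i]*(d1*float32(lo)-m1) + x[32+i]*(d2*float32(hi)-m2)
//        }
//        return sum
//    }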
TEXT ·dotQ5KInnerAVX2(SB), NOSPLIT, $0-64
    MOVQ qs+0(FP), DI
    MOVQ qh+8(FP), SI
    MOVQ x+16(FP), DX
    VBROADCASTSS d1+24(FP), Y0
    VBROADCASTSS m1+28(FP), Y1
    VBROADCASTSS d2+32(FP), Y2
    VBROADCASTSS m2+36(FP), Y3
    // low-nibble mask 0x0F
    MOVL $0x0F0F0F0F, AX
    MOVD AX, X6
    VPBROADCASTD X6, Y6
    // int32(16) constant
    MOVL $16, AX
    MOVD AX, X7
    VPBROADCASTD X7, Y7
    // mask1 = 1 << u1
    MOVQ u1+40(FP), CX
    MOVQ $1, AX
    SHLQ CX, AX
    MOVL AX, BX
    MOVD BX, X4
    VPBROADCASTD X4, Y4
    // mask2 = 1 << u2
    MOVQ u2+48(FP), CX
    MOVQ $1, AX
    SHLQ CX, AX
    MOVL AX, BX
    MOVD BX, X5
    VPBROADCASTD X5, Y5
    // zero and all-ones (int dwords)
    VXORPS Y8, Y8, Y8
    VPCMPEQD Y9, Y9, Y9
    // Accumulator
    VXORPS Y15, Y15, Y15
    MOVQ $0, CX
dot_q5k_loop:
    CMPQ CX, $32
    JGE dot_q5k_reduce
    // Load 8 qs bytes
    VPMOVZXBD (DI), Y10
    // low nibble -> int32
    VPAND Y6, Y10, Y11
    // high nibble -> int32
    VPSRLD $4, Y10, Y10
    // Load 8 qh bytes and copy
    VPMOVZXBD (SI), Y12
    VMOVAPS Y12, Y13
    // flag1 int32: (qh & mask1) ? 16 : 0
    VPAND Y4, Y12, Y12
    VPCMPEQD Y8, Y12, Y12
    VPXOR Y9, Y12, Y12
    VPAND Y7, Y12, Y12
    // low dequant: (float(lowNib + flag1) * d1) - m1
    VCVTDQ2PS Y11, Y11
    VCVTDQ2PS Y12, Y12
    VADDPS Y12, Y11, Y11
    VMULPS Y0, Y11, Y11
    VSUBPS Y1, Y11, Y11
    VMOVUPS (DX), Y14
    VFMADD231PS Y14, Y11, Y15
    // flag2 int32: (qh & mask2) ? 16 : 0
    VPAND Y5, Y13, Y13
    VPCMPEQD Y8, Y13, Y13
    VPXOR Y9, Y13, Y13
    VPAND Y7, Y13, Y13
    // high dequant: (float(highNib + flag2) * d2) - m2
    VCVTDQ2PS Y10, Y10
    VCVTDQ2PS Y13, Y13
    VADDPS Y13, Y10, Y10
    VMULPS Y2, Y10, Y10
    VSUBPS Y3, Y10, Y10
    VMOVUPS 128(DX), Y14
    VFMADD231PS Y14, Y10, Y15
    ADDQ $8, DI
    ADDQ $8, SI
    ADDQ $32, DX
    ADDQ $8, CX
    JMP dot_q5k_loop
dot_q5k_reduce:
    VEXTRACTF128 $1, Y15, X1
    VADDPS X1, X15, X0
    VHADDPS X0, X0, X0
    VHADDPS X0, X0, X0
    VMOVSS X0, ret+56(FP)
    VZEROUPPER
    RET
// ============================================================================
// Q8_K Fused Dot - AVX2
// Computes: sum_i x[i] * (d * float32(qs[i])), for i in [0..255]
// ============================================================================
// func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32
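//
// Scalar reference (a minimal Go sketch, illustrative only; field names as
// in the Q8_K layout above):
//
//    func dotQ8KRef(b *BlockQ8_K, x *[256]float32) float32 {
//        var sum float32
//        for i := 0; i < 256; i++ {
//            sum += x[i] * (b.D * float32(b.QS[i]))
//        }
//        return sum
//    }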
TEXT ·dotQ8KAVX2(SB), NOSPLIT, $0-24
    MOVQ b+0(FP), DI
    MOVQ x+8(FP), SI
    // Broadcast scale d
    VBROADCASTSS (DI), Y0
    // QS pointer = b + 4
    LEAQ 4(DI), R8
    // Accumulator Y1
    VXORPS Y1, Y1, Y1
    MOVQ $0, CX
dot_q8_256_loop:
    CMPQ CX, $256
    JGE dot_q8_256_reduce
    // 8x int8 -> 8x int32 -> 8x float
    VPMOVSXBD (R8), Y2
    VCVTDQ2PS Y2, Y2
    VMULPS Y0, Y2, Y2
    // load 8 floats from x
    VMOVUPS (SI), Y3
    VFMADD231PS Y3, Y2, Y1
    ADDQ $8, R8
    ADDQ $32, SI
    ADDQ $8, CX
    JMP dot_q8_256_loop
dot_q8_256_reduce:
    // horizontal add Y1 -> scalar in X1
    VEXTRACTF128 $1, Y1, X2
    VADDPS X2, X1, X1
    VHADDPS X1, X1, X1
    VHADDPS X1, X1, X1
    VMOVSS X1, ret+16(FP)
    VZEROUPPER
    RET
// ============================================================================
// Q4_K Inner Loop - AVX2 (Vectorized nibble extraction)
// Processes 32 4-bit quants with pre-computed scales.
// ============================================================================
// func dequantQ4KInnerAVX2(qs *byte, out *float32, d1, m1, d2, m2 float32)
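//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices):
//
//    func dequantQ4KInnerRef(qs []byte, out []float32, d1, m1, d2, m2 float32) {
//        for i := 0; i < 32; i++ {
//            out[i] = d1*float32(qs[i]&0xF) - m1   // low nibbles
//            out[32+i] = d2*float32(qs[i]>>4) - m2 // high nibbles
//        }
//    }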
TEXT ·dequantQ4KInnerAVX2(SB), NOSPLIT, $0-40
    MOVQ qs+0(FP), DI
    MOVQ out+8(FP), SI
    // Broadcast d1, m1, d2, m2
    VBROADCASTSS d1+16(FP), Y0 // d1
    VBROADCASTSS m1+20(FP), Y1 // m1
    VBROADCASTSS d2+24(FP), Y2 // d2
    VBROADCASTSS m2+28(FP), Y3 // m2
    // Mask for low nibble (0x0F repeated)
    MOVL $0x0F0F0F0F, AX
    MOVD AX, X7
    VPBROADCASTD X7, Y7
    // Process 32 quants, 8 at a time
    MOVQ $0, CX
loop_q4k:
    CMPQ CX, $32
    JGE done_q4k
    // Load 8 bytes from QS as unsigned
    VPMOVZXBD (DI), Y4 // 8 bytes -> 8 uint32
    // Extract low nibbles: v1 = val & 0xF
    VPAND Y7, Y4, Y5
    VCVTDQ2PS Y5, Y5
    VFMSUB132PS Y0, Y1, Y5 // out[i] = v1*d1 - m1
    VMOVUPS Y5, (SI)
    // Extract high nibbles: v2 = val >> 4
    VPSRLD $4, Y4, Y4
    VCVTDQ2PS Y4, Y4
    VFMSUB132PS Y2, Y3, Y4 // out[i+32] = v2*d2 - m2
    VMOVUPS Y4, 128(SI) // 32 * 4 bytes offset
    ADDQ $8, DI
    ADDQ $32, SI
    ADDQ $8, CX
    JMP loop_q4k
done_q4k:
    VZEROUPPER
    RET
// ============================================================================
// Q4_K Fused Dot - AVX2
// Computes: sum_i x[i]*(d1*(qs[i] & 0xF) - m1) + x[32+i]*(d2*(qs[i] >> 4) - m2),
// for i in [0..31].
// ============================================================================
// func dotQ4KInnerAVX2(qs *byte, x *float32, d1, m1, d2, m2 float32) float32
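//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices). The assembly below accumulates sum(x*v) and sum(x)
// separately and applies the scales/mins once in the reduction; the two
// forms are algebraically identical:
//
//    func dotQ4KInnerRef(qs []byte, x []float32, d1, m1, d2, m2 float32) float32 {
//        var sum float32
//        for i := 0; i < 32; i++ {
//            sum += x[i] * (d1*float32(qs[i]&0xF) - m1)
//            sum += x[32+i] * (d2*float32(qs[i]>>4) - m2)
//        }
//        return sum
//    }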
TEXT ·dotQ4KInnerAVX2(SB), NOSPLIT, $0-40
    MOVQ qs+0(FP), DI
    MOVQ x+8(FP), SI
    // Broadcast d1, m1, d2, m2
    VBROADCASTSS d1+16(FP), Y0
    VBROADCASTSS m1+20(FP), Y1
    VBROADCASTSS d2+24(FP), Y2
    VBROADCASTSS m2+28(FP), Y3
    // Mask for low nibble (0x0F repeated)
    MOVL $0x0F0F0F0F, AX
    MOVD AX, X7
    VPBROADCASTD X7, Y7
    // Accumulators:
    // Y12 = sum(x_low * v1)
    // Y13 = sum(x_low)
    // Y14 = sum(x_high * v2)
    // Y15 = sum(x_high)
    VXORPS Y12, Y12, Y12
    VXORPS Y13, Y13, Y13
    VXORPS Y14, Y14, Y14
    VXORPS Y15, Y15, Y15
    // Process 32 bytes as 4x8-byte chunks
    MOVQ $0, CX
dot_q4k_loop:
    CMPQ CX, $32
    JGE dot_q4k_reduce
    // Load 8 bytes from QS as unsigned dwords
    VPMOVZXBD (DI), Y8
    // Low nibble values -> float
    VPAND Y7, Y8, Y9
    VCVTDQ2PS Y9, Y9
    // x low: 8 floats
    VMOVUPS (SI), Y10
    VADDPS Y10, Y13, Y13
    VFMADD231PS Y10, Y9, Y12
    // High nibble values -> float
    VPSRLD $4, Y8, Y8
    VCVTDQ2PS Y8, Y8
    // x high: offset by 32 floats (128 bytes)
    VMOVUPS 128(SI), Y11
    VADDPS Y11, Y15, Y15
    VFMADD231PS Y11, Y8, Y14
    ADDQ $8, DI
    ADDQ $32, SI
    ADDQ $8, CX
    JMP dot_q4k_loop
dot_q4k_reduce:
    // result = d1*sum(x1*v1) - m1*sum(x1) + d2*sum(x2*v2) - m2*sum(x2)
    VMULPS Y0, Y12, Y12
    VMULPS Y1, Y13, Y13
    VSUBPS Y13, Y12, Y12
    VMULPS Y2, Y14, Y14
    VMULPS Y3, Y15, Y15
    VSUBPS Y15, Y14, Y14
    VADDPS Y14, Y12, Y12
    // Horizontal add Y12 -> scalar in X0
    VEXTRACTF128 $1, Y12, X1
    VADDPS X1, X12, X0
    VHADDPS X0, X0, X0
    VHADDPS X0, X0, X0
    VMOVSS X0, ret+32(FP)
    VZEROUPPER
    RET
// ============================================================================
// Q2_K Inner Loop - AVX2
// Processes 16 2-bit values with scale and min applied
// ============================================================================
// func dequantQ2KInnerAVX2(q *byte, out *float32, dl, ml float32, shift uint)
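//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices):
//
//    func dequantQ2KInnerRef(q []byte, out []float32, dl, ml float32, shift uint) {
//        for i := 0; i < 16; i++ {
//            out[i] = dl*float32((q[i]>>shift)&3) - ml
//        }
//    }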
TEXT ·dequantQ2KInnerAVX2(SB), NOSPLIT, $0-32
    MOVQ q+0(FP), DI
    MOVQ out+8(FP), SI
    VBROADCASTSS dl+16(FP), Y0
    VBROADCASTSS ml+20(FP), Y1
    MOVQ shift+24(FP), CX
    // Mask for 2 bits
    MOVL $0x03030303, AX
    MOVD AX, X7
    VPBROADCASTD X7, Y7
    // Load 16 bytes, extract 2-bit values
    VPMOVZXBD (DI), Y2  // First 8 bytes -> 8 int32
    VPMOVZXBD 8(DI), Y3 // Next 8 bytes -> 8 int32
    // Shift right by 'shift' and mask
    MOVD CX, X6
    VPSRLD X6, Y2, Y2
    VPAND Y7, Y2, Y2
    VPSRLD X6, Y3, Y3
    VPAND Y7, Y3, Y3
    // Convert to float and compute: dl*val - ml
    VCVTDQ2PS Y2, Y2
    VFMSUB132PS Y0, Y1, Y2
    VMOVUPS Y2, (SI)
    VCVTDQ2PS Y3, Y3
    VFMSUB132PS Y0, Y1, Y3
    VMOVUPS Y3, 32(SI)
    VZEROUPPER
    RET
// ============================================================================
// Q3_K Inner Loop - AVX2
// Processes 16 output elements (consuming 16 bytes from q)
// ============================================================================
// func dequantQ3KInnerAVX2(q *byte, hm *byte, out *float32, dl float32, m uint8, shift uint)
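//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices):
//
//    func dequantQ3KInnerRef(q, hm []byte, out []float32, dl float32, m uint8, shift uint) {
//        for i := 0; i < 16; i++ {
//            v := int32(q[i]>>shift) & 3
//            if hm[i]&m == 0 {
//                v -= 4
//            }
//            out[i] = dl * float32(v)
//        }
//    }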
TEXT ·dequantQ3KInnerAVX2(SB), NOSPLIT, $0-40
    MOVQ q+0(FP), DI
    MOVQ hm+8(FP), SI
    MOVQ out+16(FP), DX
    // Load dl (float32) -> Y0
    VBROADCASTSS dl+24(FP), Y0
    // Load shift -> X2
    MOVQ shift+32(FP), AX
    MOVD AX, X2
    // Load 16 bytes from q -> 16 words in Y3
    VPMOVZXBW (DI), Y3
    // Shift right by variable 'shift'
    VPSRLW X2, Y3, Y3
    // Mask with 3 (0x0003)
    MOVL $3, BX
    MOVD BX, X4
    VPBROADCASTW X4, Y4
    VPAND Y4, Y3, Y3 // Y3 = (q >> shift) & 3
    // Handle HMask
    // Load `m` (byte)
    MOVBLZX m+28(FP), BX
    MOVD BX, X5
    VPBROADCASTB X5, X5 // X5 = m repeated
    // Load 16 bytes hm
    VMOVDQU (SI), X6
    // Check (hm & m) == 0
    VPAND X5, X6, X6    // X6 = hm & m
    VPXOR X7, X7, X7    // Zero
    VPCMPEQB X7, X6, X6 // X6 = (hm&m == 0) ? FF : 00
    // Expand byte mask to word mask (-1 or 0)
    VPMOVSXBW X6, Y6 // Y6 = -1 or 0
    // We want to subtract 4 if mask is -1:
    // add (mask & -4).
    MOVL $-4, BX // 0xFFFFFFFC
    MOVD BX, X7
    VPBROADCASTW X7, Y7 // Y7 = -4 repeated
    VPAND Y7, Y6, Y6    // Y6 = -4 or 0
    VPADDW Y6, Y3, Y3   // Y3 = val - 4 (if needed)
    // Convert to float (Y3 has 16 int16):
    // split into low 8 (Y8) and high 8 (Y9) as int32
    VPMOVSXWD X3, Y8 // Low 8 words -> 8 int32
    VEXTRACTI128 $1, Y3, X3
    VPMOVSXWD X3, Y9 // High 8 words -> 8 int32
    VCVTDQ2PS Y8, Y8
    VCVTDQ2PS Y9, Y9
    VMULPS Y0, Y8, Y8
    VMULPS Y0, Y9, Y9
    // Store 16 floats
    VMOVUPS Y8, (DX)
    VMOVUPS Y9, 32(DX)
    VZEROUPPER
    RET
// ============================================================================
// Q6_K Inner Loop - AVX2
// Processes 128 elements (all 4 sub-blocks)
// ============================================================================
// func dequantQ6KInnerAVX2(ql *byte, qh *byte, scales *int8, out *float32, d float32)
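//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices):
//
//    func dequantQ6KInnerRef(ql, qh []byte, scales []int8, out []float32, d float32) {
//        for n := 0; n < 32; n += 16 {
//            is := n >> 4 // 0 or 1: selects the scale within each pair
//            for l := 0; l < 16; l++ {
//                q1 := int32(ql[n+l]&0xF|(qh[n+l]&3)<<4) - 32
//                q2 := int32(ql[32+n+l]&0xF|((qh[n+l]>>2)&3)<<4) - 32
//                q3 := int32(ql[n+l]>>4|((qh[n+l]>>4)&3)<<4) - 32
//                q4 := int32(ql[32+n+l]>>4|((qh[n+l]>>6)&3)<<4) - 32
//                out[n+l] = d * float32(scales[is]) * float32(q1)
//                out[32+n+l] = d * float32(scales[is+2]) * float32(q2)
//                out[64+n+l] = d * float32(scales[is+4]) * float32(q3)
//                out[96+n+l] = d * float32(scales[is+6]) * float32(q4)
//            }
//        }
//    }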
TEXT ·dequantQ6KInnerAVX2(SB), NOSPLIT, $0-40
    MOVQ ql+0(FP), DI
    MOVQ qh+8(FP), SI
    MOVQ scales+16(FP), DX
    MOVQ out+24(FP), R8
    // Broadcast d (float32) -> Y0
    VBROADCASTSS d+32(FP), Y0
    // Y15 = 0x0F (byte mask)
    MOVL $0x0F, AX
    MOVD AX, X15
    VPBROADCASTB X15, Y15
    // Y14 = 0x03 (byte mask)
    MOVL $0x03, AX
    MOVD AX, X14
    VPBROADCASTB X14, Y14
    // Y13 = 32.0 (float)
    MOVL $0x42000000, AX // float 32.0
    MOVD AX, X13
    VBROADCASTSS X13, Y13
    // R9: loop counter (takes values 0 and 16)
    MOVQ $0, R9
loop_q6k:
    CMPQ R9, $32
    JGE done_q6k
    // Load qh chunk (16 bytes) -> X1
    VMOVDQU (SI)(R9*1), X1
    // Load ql chunk 1 (16 bytes) -> X2
    VMOVDQU (DI)(R9*1), X2
    // Load ql chunk 2 (16 bytes) -> X3
    VMOVDQU 32(DI)(R9*1), X3
    // X6 = 0xF0: keeps only the bits contributed by (qh << 4)
    MOVL $0xF0, AX
    MOVD AX, X6
    VPBROADCASTB X6, X6
    // --- Q1 ---
    // (ql_c1 & 0xF) | ((qh & 3) << 4)
    VPAND X15, X2, X4
    VPAND X14, X1, X5
    VPSLLW $4, X5, X5
    VPAND X6, X5, X5
    VPOR X5, X4, X4 // X4 = Q1 values (16 bytes)
    // Scale s[0] or s[1] -> offset R9>>4
    MOVQ R9, R11
    SHRQ $4, R11 // 0 or 1
    MOVBQSX (DX)(R11*1), BX
    // Convert X4 -> Y4/Y5 (floats), subtract 32, scale by s and d, store
    VCVTSI2SSQ BX, X10, X10
    VBROADCASTSS X10, Y10 // Scale
    VMOVDQA X4, X7        // Copy X4 to X7
    VPMOVZXBD X7, Y4      // Low 8 bytes of X7 -> Y4
    VPSRLDQ $8, X4, X8    // High 8 bytes of X4 -> low half of X8
    VPMOVZXBD X8, Y5      // -> Y5
    VCVTDQ2PS Y4, Y4
    VCVTDQ2PS Y5, Y5
    VSUBPS Y13, Y4, Y4
    VSUBPS Y13, Y5, Y5
    VMULPS Y10, Y4, Y4
    VMULPS Y0, Y4, Y4
    VMULPS Y10, Y5, Y5
    VMULPS Y0, Y5, Y5
    // Store to out + R9*4
    LEAQ (R8)(R9*4), R12
    VMOVUPS Y4, (R12)
    VMOVUPS Y5, 32(R12)
    // --- Q2 ---
    // (ql_c2 & 0xF) | (((qh >> 2) & 3) << 4)
    VPAND X15, X3, X4
    VPSRLW $2, X1, X5
    VPAND X14, X5, X5
    VPSLLW $4, X5, X5
    VPAND X6, X5, X5
    VPOR X5, X4, X4
    // Scale s[2] or s[3] -> offset R9>>4 + 2
    MOVBQSX 2(DX)(R11*1), BX
    VCVTSI2SSQ BX, X10, X10
    VBROADCASTSS X10, Y10
    VMOVDQA X4, X7
    VPMOVZXBD X7, Y4
    VPSRLDQ $8, X4, X8
    VPMOVZXBD X8, Y5
    VCVTDQ2PS Y4, Y4
    VCVTDQ2PS Y5, Y5
    VSUBPS Y13, Y4, Y4
    VSUBPS Y13, Y5, Y5
    VMULPS Y10, Y4, Y4
    VMULPS Y0, Y4, Y4
    VMULPS Y10, Y5, Y5
    VMULPS Y0, Y5, Y5
    // Store to out + 128 + R9*4
    LEAQ 128(R8)(R9*4), R12
    VMOVUPS Y4, (R12)
    VMOVUPS Y5, 32(R12)
    // --- Q3 ---
    // (ql_c1 >> 4) | (((qh >> 4) & 3) << 4)
    VPSRLW $4, X2, X4
    VPAND X15, X4, X4
    VPSRLW $4, X1, X5
    VPAND X14, X5, X5
    VPSLLW $4, X5, X5
    VPAND X6, X5, X5
    VPOR X5, X4, X4
    // Scale s[4] or s[5]
    MOVBQSX 4(DX)(R11*1), BX
    VCVTSI2SSQ BX, X10, X10
    VBROADCASTSS X10, Y10
    VMOVDQA X4, X7
    VPMOVZXBD X7, Y4
    VPSRLDQ $8, X4, X8
    VPMOVZXBD X8, Y5
    VCVTDQ2PS Y4, Y4
    VCVTDQ2PS Y5, Y5
    VSUBPS Y13, Y4, Y4
    VSUBPS Y13, Y5, Y5
    VMULPS Y10, Y4, Y4
    VMULPS Y0, Y4, Y4
    VMULPS Y10, Y5, Y5
    VMULPS Y0, Y5, Y5
    LEAQ 256(R8)(R9*4), R12
    VMOVUPS Y4, (R12)
    VMOVUPS Y5, 32(R12)
    // --- Q4 ---
    // (ql_c2 >> 4) | (((qh >> 6) & 3) << 4)
    VPSRLW $4, X3, X4
    VPAND X15, X4, X4
    VPSRLW $6, X1, X5
    VPAND X14, X5, X5
    VPSLLW $4, X5, X5
    VPAND X6, X5, X5
    VPOR X5, X4, X4
    // Scale s[6] or s[7]
    MOVBQSX 6(DX)(R11*1), BX
    VCVTSI2SSQ BX, X10, X10
    VBROADCASTSS X10, Y10
    VMOVDQA X4, X7
    VPMOVZXBD X7, Y4
    VPSRLDQ $8, X4, X8
    VPMOVZXBD X8, Y5
    VCVTDQ2PS Y4, Y4
    VCVTDQ2PS Y5, Y5
    VSUBPS Y13, Y4, Y4
    VSUBPS Y13, Y5, Y5
    VMULPS Y10, Y4, Y4
    VMULPS Y0, Y4, Y4
    VMULPS Y10, Y5, Y5
    VMULPS Y0, Y5, Y5
    LEAQ 384(R8)(R9*4), R12
    VMOVUPS Y4, (R12)
    VMOVUPS Y5, 32(R12)
    ADDQ $16, R9
    JMP loop_q6k
done_q6k:
    VZEROUPPER
    RET
// ============================================================================
// Q6_K Fused Dot Inner Loop - AVX2
// Processes 128 elements (one half of a block), returns the partial dot sum.
// scales: 8 precomputed float32 values (d*scale[0..7])
// ============================================================================
// func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32
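//
// Scalar reference (a minimal Go sketch, illustrative only; pointer args
// shown as slices, with scales holding the 8 premultiplied d*scale values
// described above):
//
//    func dotQ6KInnerRef(ql, qh []byte, scales, x []float32) float32 {
//        var sum float32
//        for i := 0; i < 32; i++ {
//            is := i >> 4 // 0 or 1: selects the scale within each pair
//            q1 := int32(ql[i]&0xF|(qh[i]&3)<<4) - 32
//            q2 := int32(ql[32+i]&0xF|((qh[i]>>2)&3)<<4) - 32
//            q3 := int32(ql[i]>>4|((qh[i]>>4)&3)<<4) - 32
//            q4 := int32(ql[32+i]>>4|((qh[i]>>6)&3)<<4) - 32
//            sum += x[i] * scales[is] * float32(q1)
//            sum += x[32+i] * scales[is+2] * float32(q2)
//            sum += x[64+i] * scales[is+4] * float32(q3)
//            sum += x[96+i] * scales[is+6] * float32(q4)
//        }
//        return sum
//    }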
TEXT ·dotQ6KInnerAVX2(SB), NOSPLIT, $0-40
    MOVQ ql+0(FP), DI      // QL pointer (64 bytes)
    MOVQ qh+8(FP), SI      // QH pointer (32 bytes)
    MOVQ scales+16(FP), DX // Precomputed scales (8 floats)
    MOVQ x+24(FP), R8      // X pointer (128 floats)
    // Y11 = 0x0F as dwords (for masking after VPMOVZXBD)
    MOVL $0x0F, AX
    MOVD AX, X11
    VPBROADCASTD X11, Y11
    // Y10 = 0x03 as dwords
    MOVL $0x03, AX
    MOVD AX, X10
    VPBROADCASTD X10, Y10
    // Y9 = 32.0 (float bias)
    MOVL $0x42000000, AX
    MOVD AX, X9
    VBROADCASTSS X9, Y9
    // Y8 = accumulator for the dot product
    VXORPS Y8, Y8, Y8
    // Process 8 elements at a time (4 iterations for 32 elements).
    // Each iteration: load 8 QL bytes, 8 QH bytes, compute Q1-Q4 for 8 elements.
    MOVQ $0, R9
dotq6k_loop:
    CMPQ R9, $32
    JGE dotq6k_done
    // R11 = R9 >> 4 (0 or 1, for scale indexing)
    MOVQ R9, R11
    SHRQ $4, R11
    // Load 8 bytes of QL (for Q1/Q3), 8 more (for Q2/Q4), and 8 bytes of QH
    VPMOVZXBD (DI)(R9*1), Y0   // QL[R9..R9+7] -> 8 dwords
    VPMOVZXBD 32(DI)(R9*1), Y1 // QL[32+R9..32+R9+7] -> 8 dwords (for Q2/Q4)
    VPMOVZXBD (SI)(R9*1), Y2   // QH[R9..R9+7] -> 8 dwords
    // --- Q1: (ql & 0xF) | ((qh & 3) << 4) ---
    VPAND Y11, Y0, Y3 // ql & 0x0F
    VPAND Y10, Y2, Y4 // qh & 0x03
    VPSLLD $4, Y4, Y4 // << 4
    VPOR Y4, Y3, Y3   // combine
    VCVTDQ2PS Y3, Y3
    VSUBPS Y9, Y3, Y3 // q - 32
    VBROADCASTSS (DX)(R11*4), Y4 // scale s[0] or s[1]
    VMULPS Y4, Y3, Y3 // * scale
    LEAQ (R8)(R9*4), R12
    VMOVUPS (R12), Y5      // x[R9..R9+7]
    VFMADD231PS Y3, Y5, Y8 // acc += q * x
    // --- Q2: (ql32 & 0xF) | (((qh >> 2) & 3) << 4) ---
    VPAND Y11, Y1, Y3 // ql32 & 0x0F
    VPSRLD $2, Y2, Y4 // qh >> 2
    VPAND Y10, Y4, Y4 // & 0x03
    VPSLLD $4, Y4, Y4 // << 4
    VPOR Y4, Y3, Y3   // combine
    VCVTDQ2PS Y3, Y3
    VSUBPS Y9, Y3, Y3 // q - 32
    VBROADCASTSS 8(DX)(R11*4), Y4 // scale s[2] or s[3]
    VMULPS Y4, Y3, Y3 // * scale
    VMOVUPS 128(R12), Y5   // x[32+R9..32+R9+7]
    VFMADD231PS Y3, Y5, Y8 // acc += q * x
    // --- Q3: (ql >> 4) | (((qh >> 4) & 3) << 4) ---
    VPSRLD $4, Y0, Y3 // ql >> 4
    VPAND Y11, Y3, Y3 // & 0x0F
    VPSRLD $4, Y2, Y4 // qh >> 4
    VPAND Y10, Y4, Y4 // & 0x03
    VPSLLD $4, Y4, Y4 // << 4
    VPOR Y4, Y3, Y3   // combine
    VCVTDQ2PS Y3, Y3
    VSUBPS Y9, Y3, Y3 // q - 32
    VBROADCASTSS 16(DX)(R11*4), Y4 // scale s[4] or s[5]
    VMULPS Y4, Y3, Y3 // * scale
    VMOVUPS 256(R12), Y5   // x[64+R9..64+R9+7]
    VFMADD231PS Y3, Y5, Y8 // acc += q * x
    // --- Q4: (ql32 >> 4) | (((qh >> 6) & 3) << 4) ---
    VPSRLD $4, Y1, Y3 // ql32 >> 4
    VPAND Y11, Y3, Y3 // & 0x0F
    VPSRLD $6, Y2, Y4 // qh >> 6
    VPAND Y10, Y4, Y4 // & 0x03
    VPSLLD $4, Y4, Y4 // << 4
    VPOR Y4, Y3, Y3   // combine
    VCVTDQ2PS Y3, Y3
    VSUBPS Y9, Y3, Y3 // q - 32
    VBROADCASTSS 24(DX)(R11*4), Y4 // scale s[6] or s[7]
    VMULPS Y4, Y3, Y3 // * scale
    VMOVUPS 384(R12), Y5   // x[96+R9..96+R9+7]
    VFMADD231PS Y3, Y5, Y8 // acc += q * x
    ADDQ $8, R9
    JMP dotq6k_loop
dotq6k_done:
    // Horizontal sum of Y8
    VEXTRACTF128 $1, Y8, X0
    VADDPS X8, X0, X0
    VHADDPS X0, X0, X0
    VHADDPS X0, X0, X0
    VMOVSS X0, ret+32(FP)
    VZEROUPPER
    RET