simd_dequant.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. package tensor
  2. import (
  3. "sync"
  4. "golang.org/x/sys/cpu"
  5. )
  6. var (
  7. hasAVX512 = cpu.X86.HasAVX512F && cpu.X86.HasAVX512DQ && cpu.X86.HasAVX512BW && cpu.X86.HasAVX512VL
  8. hasAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasFMA
  9. )
  10. var q6kSimdOnce sync.Once
  11. var q6kSimdOK bool
  // absDiffF32 returns the absolute difference |a - b| of two float32 values.
// If either operand is NaN the result is NaN.
func absDiffF32(a, b float32) float32 {
	larger, smaller := b, a
	if a > b {
		larger, smaller = a, b
	}
	return larger - smaller
}
  18. func dequantizeQ6KScalar(b *BlockQ6_K, out []float32) {
  19. d := FP16ToFP32(b.D)
  20. qlPtr := 0
  21. qhPtr := 0
  22. scPtr := 0
  23. outPtr := 0
  24. for n := 0; n < 256; n += 128 {
  25. for l := 0; l < 32; l++ {
  26. is := l / 16
  27. ql0 := b.QL[qlPtr+l]
  28. ql32 := b.QL[qlPtr+l+32]
  29. qh := b.QH[qhPtr+l]
  30. q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
  31. q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
  32. q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
  33. q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
  34. out[outPtr+l+0] = d * float32(b.Scales[scPtr+is+0]) * float32(q1)
  35. out[outPtr+l+32] = d * float32(b.Scales[scPtr+is+2]) * float32(q2)
  36. out[outPtr+l+64] = d * float32(b.Scales[scPtr+is+4]) * float32(q3)
  37. out[outPtr+l+96] = d * float32(b.Scales[scPtr+is+6]) * float32(q4)
  38. }
  39. outPtr += 128
  40. qlPtr += 64
  41. qhPtr += 32
  42. scPtr += 8
  43. }
  44. }
  45. func q6kSimdReady() bool {
  46. if !hasAVX2 {
  47. return false
  48. }
  49. q6kSimdOnce.Do(func() {
  50. var b BlockQ6_K
  51. b.D = 0x3C00
  52. for i := range b.Scales {
  53. b.Scales[i] = int8((i % 16) - 8)
  54. }
  55. for i := range b.QL {
  56. b.QL[i] = uint8(i)
  57. }
  58. for i := range b.QH {
  59. b.QH[i] = uint8(i * 3)
  60. }
  61. var outScalar [256]float32
  62. dequantizeQ6KScalar(&b, outScalar[:])
  63. var outSimd [256]float32
  64. d := FP16ToFP32(b.D)
  65. dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &outSimd[0], d)
  66. dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &outSimd[128], d)
  67. for i := 0; i < 256; i++ {
  68. if absDiffF32(outSimd[i], outScalar[i]) > 1e-4 {
  69. q6kSimdOK = false
  70. return
  71. }
  72. }
  73. q6kSimdOK = true
  74. })
  75. return q6kSimdOK
  76. }
  77. // ============================================================================
  78. // Assembly Declarations
  79. // ============================================================================
  80. // Q8_K - Full AVX2/AVX512 vectorization
  81. //
  82. //go:noescape
  83. func dequantQ8KAVX512(b *BlockQ8_K, out *float32)
  84. //go:noescape
  85. func dequantQ8KAVX2(b *BlockQ8_K, out *float32)
  86. //go:noescape
  87. func dotQ8KAVX512(b *BlockQ8_K, x *float32) float32
  88. //go:noescape
  89. func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32
  90. // Q4_K - Inner loop vectorization (32 quants at a time)
  91. //
  92. //go:noescape
  93. func dequantQ4KInnerAVX2(qs *byte, out *float32, d1, m1, d2, m2 float32)
  94. // Q4_K - Fused dot product (32 bytes -> 64 quants) against 64 float32 inputs.
  95. //
  96. //go:noescape
  97. func dotQ4KInnerAVX512(qs *byte, x *float32, d1, m1, d2, m2 float32) float32
  98. //go:noescape
  99. func dotQ4KInnerAVX2(qs *byte, x *float32, d1, m1, d2, m2 float32) float32
  100. //go:noescape
  101. func dotQ5KInnerAVX512(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
  102. //go:noescape
  103. func dotQ5KInnerAVX2(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
  104. // Q2_K - Inner loop vectorization (16 quants at a time)
  105. //
  106. //go:noescape
  107. func dequantQ2KInnerAVX2(q *byte, out *float32, dl, ml float32, shift uint)
  108. //go:noescape
  109. func dotQ2KInnerAVX2Fused(q *byte, x *float32, dl, ml float32, shift uint) float32
  110. // Q3_K - Inner loop vectorization
  111. //
  112. //go:noescape
  113. func dequantQ3KInnerAVX2(q *byte, hm *byte, out *float32, dl float32, m uint8, shift uint)
  114. //go:noescape
  115. func dotQ3KInnerAVX2Fused(q *byte, hm *byte, x *float32, dl float32, m uint8, shift uint) float32
  116. // Q6_K - Inner loop vectorization
  117. //
  118. //go:noescape
  119. func dequantQ6KInnerAVX2(ql *byte, qh *byte, scales *int8, out *float32, d float32)
  120. //go:noescape
  121. func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32
  122. //go:noescape
  123. func dotQ6KInnerAVX512(ql *byte, qh *byte, scales *float32, x *float32) float32
  124. // ============================================================================
  125. // Q8_K Dequantization
  126. // ============================================================================
  127. // dequantQ8KSimd attempts a vector-friendly dequant for Q8_K.
  128. // Returns true if the fast path was taken.
  129. func dequantQ8KSimd(b *BlockQ8_K, out []float32) bool {
  130. if hasAVX512 {
  131. dequantQ8KAVX512(b, &out[0])
  132. return true
  133. }
  134. if hasAVX2 {
  135. dequantQ8KAVX2(b, &out[0])
  136. return true
  137. }
  138. return false
  139. }
  140. // ============================================================================
  141. // Q4_K Dequantization
  142. // ============================================================================
  143. // dequantQ4KSimd performs vectorized Q4_K dequantization using AVX2.
  144. func dequantQ4KSimd(b *BlockQ4_K, out []float32) bool {
  145. if !hasAVX2 {
  146. return false
  147. }
  148. d := FP16ToFP32(b.D)
  149. dmin := FP16ToFP32(b.DMin)
  150. // Decode 6-bit scales and mins
  151. var sc [8]uint8
  152. var m [8]uint8
  153. for j := 0; j < 4; j++ {
  154. sc[j] = b.Scales[j] & 63
  155. m[j] = b.Scales[j+4] & 63
  156. }
  157. for j := 4; j < 8; j++ {
  158. sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
  159. m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
  160. }
  161. outPtr := 0
  162. qsPtr := 0
  163. for i := 0; i < 8; i += 2 {
  164. d1 := d * float32(sc[i])
  165. m1 := dmin * float32(m[i])
  166. d2 := d * float32(sc[i+1])
  167. m2 := dmin * float32(m[i+1])
  168. // Use AVX2 kernel for inner loop
  169. dequantQ4KInnerAVX2(&b.QS[qsPtr], &out[outPtr], d1, m1, d2, m2)
  170. outPtr += 64
  171. qsPtr += 32
  172. }
  173. return true
  174. }
  175. func dotQ4KSimd(b *BlockQ4_K, x []float32) (float32, bool) {
  176. useAVX512 := hasAVX512
  177. useAVX2 := hasAVX2
  178. if !useAVX512 && !useAVX2 {
  179. return 0, false
  180. }
  181. if len(x) != QK_K {
  182. return 0, false
  183. }
  184. d := FP16ToFP32(b.D)
  185. dmin := FP16ToFP32(b.DMin)
  186. // Decode 6-bit scales and mins
  187. var sc [8]uint8
  188. var m [8]uint8
  189. for j := 0; j < 4; j++ {
  190. sc[j] = b.Scales[j] & 63
  191. m[j] = b.Scales[j+4] & 63
  192. }
  193. for j := 4; j < 8; j++ {
  194. sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
  195. m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
  196. }
  197. var sum float32
  198. outPtr := 0
  199. qsPtr := 0
  200. for i := 0; i < 8; i += 2 {
  201. d1 := d * float32(sc[i])
  202. m1 := dmin * float32(m[i])
  203. d2 := d * float32(sc[i+1])
  204. m2 := dmin * float32(m[i+1])
  205. if useAVX512 {
  206. sum += dotQ4KInnerAVX512(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
  207. } else {
  208. sum += dotQ4KInnerAVX2(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
  209. }
  210. outPtr += 64
  211. qsPtr += 32
  212. }
  213. return sum, true
  214. }
  215. func dotQ8KSimd(b *BlockQ8_K, x []float32) (float32, bool) {
  216. useAVX512 := hasAVX512
  217. useAVX2 := hasAVX2
  218. if !useAVX512 && !useAVX2 {
  219. return 0, false
  220. }
  221. if len(x) != QK_K {
  222. return 0, false
  223. }
  224. if useAVX512 {
  225. return dotQ8KAVX512(b, &x[0]), true
  226. }
  227. return dotQ8KAVX2(b, &x[0]), true
  228. }
  229. func dotQ2KSimd(b *BlockQ2_K, x []float32) (float32, bool) {
  230. if !hasAVX2 {
  231. return 0, false
  232. }
  233. if len(x) != QK_K {
  234. return 0, false
  235. }
  236. d := FP16ToFP32(b.D)
  237. dmin := FP16ToFP32(b.DMin)
  238. is := 0
  239. xIdx := 0
  240. qOffset := 0
  241. var sum float32
  242. for n := 0; n < QK_K; n += 128 {
  243. for shift := uint(0); shift < 8; shift += 2 {
  244. sc := b.Scales[is]
  245. is++
  246. dl := d * float32(sc&0xF)
  247. ml := dmin * float32(sc>>4)
  248. sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset], &x[xIdx], dl, ml, shift)
  249. xIdx += 16
  250. sc = b.Scales[is]
  251. is++
  252. dl = d * float32(sc&0xF)
  253. ml = dmin * float32(sc>>4)
  254. sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], &x[xIdx], dl, ml, shift)
  255. xIdx += 16
  256. }
  257. qOffset += 32
  258. }
  259. return sum, true
  260. }
  261. func dotQ3KSimd(b *BlockQ3_K, x []float32) (float32, bool) {
  262. if !hasAVX2 {
  263. return 0, false
  264. }
  265. if len(x) != QK_K {
  266. return 0, false
  267. }
  268. d := FP16ToFP32(b.D)
  269. var scales [16]float32
  270. for i := 0; i < 16; i++ {
  271. scales[i] = float32(unpackQ3Scale(b.Scales[:], i))
  272. }
  273. q := b.QS[:]
  274. hm := b.HMask[:]
  275. is := 0
  276. xIdx := 0
  277. m := uint8(1)
  278. var sum float32
  279. for n := 0; n < QK_K; n += 128 {
  280. for j := 0; j < 4; j++ {
  281. dl := d * scales[is]
  282. sum += dotQ3KInnerAVX2Fused(&q[0], &hm[0], &x[xIdx], dl, m, uint(j*2))
  283. xIdx += 16
  284. is++
  285. dl = d * scales[is]
  286. sum += dotQ3KInnerAVX2Fused(&q[16], &hm[16], &x[xIdx], dl, m, uint(j*2))
  287. xIdx += 16
  288. is++
  289. m <<= 1
  290. }
  291. q = q[32:]
  292. }
  293. return sum, true
  294. }
  295. // ============================================================================
  296. // Q2_K Dequantization
  297. // ============================================================================
  298. // dequantQ2KSimd performs vectorized Q2_K dequantization.
  299. func dequantQ2KSimd(b *BlockQ2_K, out []float32) bool {
  300. if !hasAVX2 {
  301. return false
  302. }
  303. d := FP16ToFP32(b.D)
  304. dmin := FP16ToFP32(b.DMin)
  305. is := 0
  306. outIdx := 0
  307. qOffset := 0
  308. for n := 0; n < QK_K; n += 128 {
  309. for shift := uint(0); shift < 8; shift += 2 {
  310. sc := b.Scales[is]
  311. is++
  312. dl := d * float32(sc&0xF)
  313. ml := dmin * float32(sc>>4)
  314. // Process 16 elements with AVX2
  315. dequantQ2KInnerAVX2(&b.QS[qOffset], &out[outIdx], dl, ml, shift)
  316. outIdx += 16
  317. sc = b.Scales[is]
  318. is++
  319. dl = d * float32(sc&0xF)
  320. ml = dmin * float32(sc>>4)
  321. // Process next 16 elements
  322. dequantQ2KInnerAVX2(&b.QS[qOffset+16], &out[outIdx], dl, ml, shift)
  323. outIdx += 16
  324. }
  325. qOffset += 32
  326. }
  327. return true
  328. }
  329. // ============================================================================
  330. // Q3_K Dequantization
  331. // ============================================================================
  332. // dequantQ3KSimd returns false because benchmarks showed the scalar path is faster.
  333. // Q3_K has complex 3-bit + high-bit packing that doesn't benefit from SIMD.
  334. // Scalar: 443ns, Unrolled Go: 502ns
  335. func dequantQ3KSimd(b *BlockQ3_K, out []float32) bool {
  336. if !hasAVX2 {
  337. return false
  338. }
  339. d := FP16ToFP32(b.D)
  340. var scales [16]float32
  341. for i := 0; i < 16; i++ {
  342. scales[i] = float32(unpackQ3Scale(b.Scales[:], i))
  343. }
  344. q := b.QS[:]
  345. hm := b.HMask[:]
  346. outIdx := 0
  347. is := 0
  348. m := uint8(1)
  349. // Same loop structure as scalar, but vectorized inner loop
  350. for n := 0; n < QK_K; n += 128 {
  351. for j := 0; j < 4; j++ {
  352. // First 16
  353. dl1 := d * scales[is]
  354. dequantQ3KInnerAVX2(&q[0], &hm[0], &out[outIdx], dl1, m, uint(j*2))
  355. is++
  356. outIdx += 16
  357. // Second 16 (offset by 16 in q/hm)
  358. dl2 := d * scales[is]
  359. dequantQ3KInnerAVX2(&q[16], &hm[16], &out[outIdx], dl2, m, uint(j*2))
  360. is++
  361. outIdx += 16
  362. // In scalar, mask m shifts left
  363. m <<= 1
  364. }
  365. q = q[32:]
  366. // reset mask for next 128 block?
  367. // Wait, scalar: m starts at 1. Correct.
  368. }
  369. return true
  370. }
  371. // ============================================================================
  372. // Q6_K Dequantization
  373. // ============================================================================
  374. // dequantQ6KSimd returns false because benchmarks showed the scalar path is equivalent.
  375. // Q6_K has complex 6-bit packing that doesn't benefit from our Go-based optimization.
  376. // Scalar: 515ns, Unrolled Go: 521ns
  377. func dequantQ6KSimd(b *BlockQ6_K, out []float32) bool {
  378. // disabled: verification failure in block calculation
  379. if !q6kSimdReady() {
  380. return false
  381. }
  382. if len(out) != QK_K {
  383. return false
  384. }
  385. d := FP16ToFP32(b.D)
  386. dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &out[0], d)
  387. dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &out[128], d)
  388. return true
  389. }
  390. func dotQ6KSimd(b *BlockQ6_K, x []float32) (float32, bool) {
  391. if !q6kSimdReady() {
  392. return 0, false
  393. }
  394. if len(x) != QK_K {
  395. return 0, false
  396. }
  397. var tmp [128]float32
  398. d := FP16ToFP32(b.D)
  399. var sum float32
  400. dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &tmp[0], d)
  401. for i := 0; i < 128; i++ {
  402. sum += x[i] * tmp[i]
  403. }
  404. dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &tmp[0], d)
  405. for i := 0; i < 128; i++ {
  406. sum += x[128+i] * tmp[i]
  407. }
  408. return sum, true
  409. }