1
0

quant.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480
  1. package tensor
  2. import (
  3. "math"
  4. "sync"
  5. "unsafe"
  6. )
  7. // QK_K is the super-block size for K-quants
  8. const QK_K = 256
  9. // unpackQ3Scale extracts a signed 6-bit scale from packed bytes (llama.cpp layout)
  10. func unpackQ3Scale(packed []byte, idx int) int8 {
  11. var sc uint8
  12. if idx < 8 {
  13. sc = packed[idx] & 0xF
  14. } else {
  15. sc = packed[idx-8] >> 4
  16. }
  17. sc |= ((packed[8+(idx%4)] >> (2 * (idx / 4))) & 0x3) << 4
  18. return int8(sc) - 32
  19. }
  20. // BlockQ4_K represents a block of 256 weights quantized to 4 bits with super-block scales
  21. // Layout (144 bytes):
  22. // - D (2 bytes): float16 super-scale
  23. // - DMin (2 bytes): float16 super-min-scale
  24. // - Scales (12 bytes): 8 6-bit scales and 8 6-bit mins packed
  25. // - QS (128 bytes): 256 4-bit quants
  26. type BlockQ4_K struct {
  27. D uint16
  28. DMin uint16
  29. Scales [12]uint8
  30. QS [128]uint8
  31. }
  32. type BlockQ5_K struct {
  33. D uint16
  34. DMin uint16
  35. Scales [12]uint8
  36. QH [32]uint8
  37. QS [128]uint8
  38. }
  39. func getScaleMinK4(j int, q *[12]uint8) (d uint8, m uint8) {
  40. if j < 4 {
  41. d = q[j] & 63
  42. m = q[j+4] & 63
  43. return d, m
  44. }
  45. d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4)
  46. m = (q[j+4] >> 4) | ((q[j] >> 6) << 4)
  47. return d, m
  48. }
  49. // FP16ToFP32 converts a float16 (as uint16) to float32.
  50. // Implements full IEEE 754 half-precision conversion.
  51. func FP16ToFP32(n uint16) float32 {
  52. sign := uint32(n&0x8000) << 16
  53. exp := uint32(n&0x7C00) >> 10
  54. mant := uint32(n & 0x03FF)
  55. // Normalized case (most common for model weights)
  56. if exp > 0 && exp < 0x1F {
  57. return math.Float32frombits(sign | ((exp + 112) << 23) | (mant << 13))
  58. }
  59. // Zero or Denormalized
  60. if exp == 0 {
  61. if mant == 0 {
  62. return math.Float32frombits(sign)
  63. }
  64. // Denormalized number
  65. // Renormalize: multiply by 2^(-14)
  66. // 1024.0 is 2^10
  67. m := float32(mant) / 1024.0
  68. val := m * float32(math.Pow(2, -14))
  69. if sign != 0 {
  70. val = -val
  71. }
  72. return val
  73. }
  74. // Infinity or NaN (exp == 0x1F)
  75. if mant == 0 {
  76. return math.Float32frombits(sign | 0x7F800000) // Infinity
  77. }
  78. return math.Float32frombits(sign | 0x7FC00000 | (mant << 13)) // NaN
  79. }
  80. // DequantizeQ4_K dequantizes a single Q4_K block into 256 floats
  81. func DequantizeQ4_K(b *BlockQ4_K, out []float32) {
  82. if dequantQ4KSimd(b, out) {
  83. return
  84. }
  85. d := FP16ToFP32(b.D)
  86. dmin := FP16ToFP32(b.DMin)
  87. var sc [8]uint8
  88. var m [8]uint8
  89. for j := 0; j < 4; j++ {
  90. sc[j] = b.Scales[j] & 63
  91. m[j] = b.Scales[j+4] & 63
  92. }
  93. for j := 4; j < 8; j++ {
  94. sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
  95. m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
  96. }
  97. outPtr := 0
  98. qsPtr := 0
  99. for i := 0; i < 8; i += 2 {
  100. d1 := d * float32(sc[i])
  101. m1 := dmin * float32(m[i])
  102. d2 := d * float32(sc[i+1])
  103. m2 := dmin * float32(m[i+1])
  104. for l := 0; l < 32; l++ {
  105. val := b.QS[qsPtr+l]
  106. v1 := val & 0xF
  107. v2 := val >> 4
  108. out[outPtr] = float32(v1)*d1 - m1
  109. out[outPtr+32] = float32(v2)*d2 - m2
  110. outPtr++
  111. }
  112. outPtr += 32
  113. qsPtr += 32
  114. }
  115. }
  116. func DotQ4_K(b *BlockQ4_K, x []float32) float32 {
  117. if len(x) != QK_K {
  118. panic("DotQ4_K: mismatched slice length")
  119. }
  120. if sum, ok := dotQ4KSimd(b, x); ok {
  121. return sum
  122. }
  123. d := FP16ToFP32(b.D)
  124. dmin := FP16ToFP32(b.DMin)
  125. var sc [8]uint8
  126. var m [8]uint8
  127. for j := 0; j < 4; j++ {
  128. sc[j] = b.Scales[j] & 63
  129. m[j] = b.Scales[j+4] & 63
  130. }
  131. for j := 4; j < 8; j++ {
  132. sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
  133. m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
  134. }
  135. var sum float32
  136. outPtr := 0
  137. qsPtr := 0
  138. for i := 0; i < 8; i += 2 {
  139. d1 := d * float32(sc[i])
  140. m1 := dmin * float32(m[i])
  141. d2 := d * float32(sc[i+1])
  142. m2 := dmin * float32(m[i+1])
  143. for l := 0; l < 32; l++ {
  144. val := b.QS[qsPtr+l]
  145. v1 := val & 0xF
  146. v2 := val >> 4
  147. sum += x[outPtr] * (float32(v1)*d1 - m1)
  148. sum += x[outPtr+32] * (float32(v2)*d2 - m2)
  149. outPtr++
  150. }
  151. outPtr += 32
  152. qsPtr += 32
  153. }
  154. return sum
  155. }
  156. func DotQ2_K_Params(b *BlockQ2_K, p *Q2KDotParams, x []float32) float32 {
  157. if len(x) != QK_K {
  158. panic("DotQ2_K_Params: mismatched slice length")
  159. }
  160. if hasAVX2 {
  161. is := 0
  162. xIdx := 0
  163. qOffset := 0
  164. var sum float32
  165. for n := 0; n < QK_K; n += 128 {
  166. _ = n
  167. shift := uint(0)
  168. for j := 0; j < 4; j++ {
  169. dl := p.DL[is]
  170. ml := p.ML[is]
  171. is++
  172. sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset], &x[xIdx], dl, ml, shift)
  173. xIdx += 16
  174. dl = p.DL[is]
  175. ml = p.ML[is]
  176. is++
  177. sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], &x[xIdx], dl, ml, shift)
  178. xIdx += 16
  179. shift += 2
  180. }
  181. qOffset += 32
  182. }
  183. return sum
  184. }
  185. q := b.QS[:]
  186. is := 0
  187. outIdx := 0
  188. var sum float32
  189. for n := 0; n < QK_K; n += 128 {
  190. shift := uint(0)
  191. for j := 0; j < 4; j++ {
  192. dl := p.DL[is]
  193. ml := p.ML[is]
  194. is++
  195. for l := 0; l < 16; l++ {
  196. val := float32(int8((q[l] >> shift) & 3))
  197. sum += x[outIdx] * (dl*val - ml)
  198. outIdx++
  199. }
  200. dl = p.DL[is]
  201. ml = p.ML[is]
  202. is++
  203. for l := 0; l < 16; l++ {
  204. val := float32(int8((q[l+16] >> shift) & 3))
  205. sum += x[outIdx] * (dl*val - ml)
  206. outIdx++
  207. }
  208. shift += 2
  209. }
  210. q = q[32:]
  211. }
  212. return sum
  213. }
  214. func DotQ2KTile8(sums *[8]float32, w []BlockQ2_K, wp []Q2KDotParams, base int, stride int, x *float32, n int) {
  215. if n <= 0 {
  216. return
  217. }
  218. xp := unsafe.Pointer(x)
  219. if hasAVX2 {
  220. for t := 0; t < n; t++ {
  221. idx := base + t*stride
  222. b := &w[idx]
  223. p := &wp[idx]
  224. is := 0
  225. xIdx := 0
  226. qOffset := 0
  227. for nn := 0; nn < QK_K; nn += 128 {
  228. _ = nn
  229. shift := uint(0)
  230. for j := 0; j < 4; j++ {
  231. dl := p.DL[is]
  232. ml := p.ML[is]
  233. is++
  234. xSeg := (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
  235. sums[t] += dotQ2KInnerAVX2Fused(&b.QS[qOffset], xSeg, dl, ml, shift)
  236. xIdx += 16
  237. dl = p.DL[is]
  238. ml = p.ML[is]
  239. is++
  240. xSeg = (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
  241. sums[t] += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], xSeg, dl, ml, shift)
  242. xIdx += 16
  243. shift += 2
  244. }
  245. qOffset += 32
  246. }
  247. }
  248. return
  249. }
  250. xSlice := unsafe.Slice(x, QK_K)
  251. for t := 0; t < n; t++ {
  252. idx := base + t*stride
  253. b := &w[idx]
  254. p := &wp[idx]
  255. sums[t] += DotQ2_K_Params(b, p, xSlice)
  256. }
  257. }
  258. func DotQ8KTile8(sums *[8]float32, w []BlockQ8_K, base int, stride int, x *float32, n int) {
  259. if n <= 0 {
  260. return
  261. }
  262. if hasAVX512 {
  263. for t := 0; t < n; t++ {
  264. idx := base + t*stride
  265. sums[t] += dotQ8KAVX512(&w[idx], x)
  266. }
  267. return
  268. }
  269. if hasAVX2 {
  270. for t := 0; t < n; t++ {
  271. idx := base + t*stride
  272. sums[t] += dotQ8KAVX2(&w[idx], x)
  273. }
  274. return
  275. }
  276. xSlice := unsafe.Slice(x, QK_K)
  277. for t := 0; t < n; t++ {
  278. idx := base + t*stride
  279. sums[t] += DotQ8_K(&w[idx], xSlice)
  280. }
  281. }
  282. func DotQ3_K_Params(b *BlockQ3_K, p *Q3KDotParams, x []float32) float32 {
  283. if len(x) != QK_K {
  284. panic("DotQ3_K_Params: mismatched slice length")
  285. }
  286. if hasAVX2 {
  287. q := b.QS[:]
  288. hm := b.HMask[:]
  289. is := 0
  290. xIdx := 0
  291. m := uint8(1)
  292. var sum float32
  293. for n := 0; n < QK_K; n += 128 {
  294. for j := 0; j < 4; j++ {
  295. dl := p.S[is]
  296. sum += dotQ3KInnerAVX2Fused(&q[0], &hm[0], &x[xIdx], dl, m, uint(j*2))
  297. xIdx += 16
  298. is++
  299. dl = p.S[is]
  300. sum += dotQ3KInnerAVX2Fused(&q[16], &hm[16], &x[xIdx], dl, m, uint(j*2))
  301. xIdx += 16
  302. is++
  303. m <<= 1
  304. }
  305. q = q[32:]
  306. }
  307. return sum
  308. }
  309. q := b.QS[:]
  310. hm := b.HMask[:]
  311. outIdx := 0
  312. is := 0
  313. m := uint8(1)
  314. var sum float32
  315. for n := 0; n < QK_K; n += 128 {
  316. shift := uint(0)
  317. for j := 0; j < 4; j++ {
  318. dl := p.S[is]
  319. is++
  320. for l := 0; l < 16; l++ {
  321. qv := int8((q[l] >> shift) & 0x3)
  322. if hm[l]&m == 0 {
  323. qv -= 4
  324. }
  325. sum += x[outIdx] * (dl * float32(qv))
  326. outIdx++
  327. }
  328. dl = p.S[is]
  329. is++
  330. for l := 0; l < 16; l++ {
  331. qv := int8((q[l+16] >> shift) & 0x3)
  332. if hm[l+16]&m == 0 {
  333. qv -= 4
  334. }
  335. sum += x[outIdx] * (dl * float32(qv))
  336. outIdx++
  337. }
  338. shift += 2
  339. m <<= 1
  340. }
  341. q = q[32:]
  342. }
  343. return sum
  344. }
  345. func DotQ3KTile8(sums *[8]float32, w []BlockQ3_K, wp []Q3KDotParams, base int, stride int, x *float32, n int) {
  346. if n <= 0 {
  347. return
  348. }
  349. xp := unsafe.Pointer(x)
  350. if hasAVX2 {
  351. for t := 0; t < n; t++ {
  352. idx := base + t*stride
  353. b := &w[idx]
  354. p := &wp[idx]
  355. q := b.QS[:]
  356. hm := b.HMask[:]
  357. is := 0
  358. xIdx := 0
  359. m := uint8(1)
  360. for nn := 0; nn < QK_K; nn += 128 {
  361. for j := 0; j < 4; j++ {
  362. dl := p.S[is]
  363. xSeg := (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
  364. sums[t] += dotQ3KInnerAVX2Fused(&q[0], &hm[0], xSeg, dl, m, uint(j*2))
  365. xIdx += 16
  366. is++
  367. dl = p.S[is]
  368. xSeg = (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
  369. sums[t] += dotQ3KInnerAVX2Fused(&q[16], &hm[16], xSeg, dl, m, uint(j*2))
  370. xIdx += 16
  371. is++
  372. m <<= 1
  373. }
  374. q = q[32:]
  375. }
  376. }
  377. return
  378. }
  379. for t := 0; t < n; t++ {
  380. idx := base + t*stride
  381. b := &w[idx]
  382. p := &wp[idx]
  383. xIdx := 0
  384. q := b.QS[:]
  385. hm := b.HMask[:]
  386. is := 0
  387. m := uint8(1)
  388. for nn := 0; nn < QK_K; nn += 128 {
  389. shift := uint(0)
  390. for j := 0; j < 4; j++ {
  391. dl := p.S[is]
  392. is++
  393. for l := 0; l < 16; l++ {
  394. qv := int8((q[l] >> shift) & 0x3)
  395. if hm[l]&m == 0 {
  396. qv -= 4
  397. }
  398. x0 := *(*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
  399. sums[t] += x0 * (dl * float32(qv))
  400. xIdx++
  401. }
  402. dl = p.S[is]
  403. is++
  404. for l := 0; l < 16; l++ {
  405. qv := int8((q[l+16] >> shift) & 0x3)
  406. if hm[l+16]&m == 0 {
  407. qv -= 4
  408. }
  409. x0 := *(*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
  410. sums[t] += x0 * (dl * float32(qv))
  411. xIdx++
  412. }
  413. shift += 2
  414. m <<= 1
  415. }
  416. q = q[32:]
  417. }
  418. }
  419. }
  420. func DotQ6_K_Params(b *BlockQ6_K, p *Q6KDotParams, x []float32) float32 {
  421. if len(x) != QK_K {
  422. panic("DotQ6_K_Params: mismatched slice length")
  423. }
  424. if hasAVX512 {
  425. sum := dotQ6KInnerAVX512(&b.QL[0], &b.QH[0], &p.S[0], &x[0])
  426. sum += dotQ6KInnerAVX512(&b.QL[64], &b.QH[32], &p.S[8], &x[128])
  427. return sum
  428. }
  429. if hasAVX2 {
  430. sum := dotQ6KInnerAVX2(&b.QL[0], &b.QH[0], &p.S[0], &x[0])
  431. sum += dotQ6KInnerAVX2(&b.QL[64], &b.QH[32], &p.S[8], &x[128])
  432. return sum
  433. }
  434. qlPtr := 0
  435. qhPtr := 0
  436. outPtr := 0
  437. var sum float32
  438. for half := 0; half < 2; half++ {
  439. scBase := half * 8
  440. for l := 0; l < 32; l++ {
  441. is := l / 16
  442. ql0 := b.QL[qlPtr+l]
  443. ql32 := b.QL[qlPtr+l+32]
  444. qh := b.QH[qhPtr+l]
  445. q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
  446. q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
  447. q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
  448. q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
  449. s0 := p.S[scBase+is+0]
  450. s2 := p.S[scBase+is+2]
  451. s4 := p.S[scBase+is+4]
  452. s6 := p.S[scBase+is+6]
  453. base := outPtr + l
  454. sum += x[base+0] * s0 * float32(q1)
  455. sum += x[base+32] * s2 * float32(q2)
  456. sum += x[base+64] * s4 * float32(q3)
  457. sum += x[base+96] * s6 * float32(q4)
  458. }
  459. outPtr += 128
  460. qlPtr += 64
  461. qhPtr += 32
  462. }
  463. return sum
  464. }
  465. func DotQ6KTile8(sums *[8]float32, w []BlockQ6_K, wp []Q6KDotParams, base int, stride int, x *float32, n int) {
  466. if n <= 0 {
  467. return
  468. }
  469. xp := unsafe.Pointer(x)
  470. if hasAVX512 {
  471. for t := 0; t < n; t++ {
  472. idx := base + t*stride
  473. b := &w[idx]
  474. p := &wp[idx]
  475. sums[t] += dotQ6KInnerAVX512(&b.QL[0], &b.QH[0], &p.S[0], x)
  476. x128 := (*float32)(unsafe.Add(xp, 128*4))
  477. sums[t] += dotQ6KInnerAVX512(&b.QL[64], &b.QH[32], &p.S[8], x128)
  478. }
  479. return
  480. }
  481. if hasAVX2 {
  482. for t := 0; t < n; t++ {
  483. idx := base + t*stride
  484. b := &w[idx]
  485. p := &wp[idx]
  486. sums[t] += dotQ6KInnerAVX2(&b.QL[0], &b.QH[0], &p.S[0], x)
  487. x128 := (*float32)(unsafe.Add(xp, 128*4))
  488. sums[t] += dotQ6KInnerAVX2(&b.QL[64], &b.QH[32], &p.S[8], x128)
  489. }
  490. return
  491. }
  492. for half := 0; half < 2; half++ {
  493. xBase := unsafe.Add(xp, uintptr(half*128)*4)
  494. qlPtr := half * 64
  495. qhPtr := half * 32
  496. scBase := half * 8
  497. for l := 0; l < 32; l++ {
  498. x0 := *(*float32)(unsafe.Add(xBase, uintptr(l)*4))
  499. x1 := *(*float32)(unsafe.Add(xBase, uintptr(32+l)*4))
  500. x2 := *(*float32)(unsafe.Add(xBase, uintptr(64+l)*4))
  501. x3 := *(*float32)(unsafe.Add(xBase, uintptr(96+l)*4))
  502. is := l / 16
  503. for t := 0; t < n; t++ {
  504. idx := base + t*stride
  505. b := &w[idx]
  506. p := &wp[idx]
  507. ql0 := b.QL[qlPtr+l]
  508. ql32 := b.QL[qlPtr+l+32]
  509. qh := b.QH[qhPtr+l]
  510. q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
  511. q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
  512. q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
  513. q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
  514. s0 := p.S[scBase+is+0]
  515. s2 := p.S[scBase+is+2]
  516. s4 := p.S[scBase+is+4]
  517. s6 := p.S[scBase+is+6]
  518. sums[t] += x0*s0*float32(q1) + x1*s2*float32(q2) + x2*s4*float32(q3) + x3*s6*float32(q4)
  519. }
  520. }
  521. }
  522. }
  523. type Q5KDotParams struct {
  524. D1 [4]float32
  525. M1 [4]float32
  526. D2 [4]float32
  527. M2 [4]float32
  528. }
  529. type q5kParamsKey struct {
  530. p unsafe.Pointer
  531. n int
  532. }
  533. var q5kDotParamsCache sync.Map
  534. func GetQ5KDotParams(blocks []BlockQ5_K) []Q5KDotParams {
  535. if len(blocks) == 0 {
  536. return nil
  537. }
  538. key := q5kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
  539. if v, ok := q5kDotParamsCache.Load(key); ok {
  540. return v.([]Q5KDotParams)
  541. }
  542. params := make([]Q5KDotParams, len(blocks))
  543. for bi := range blocks {
  544. b := &blocks[bi]
  545. d := FP16ToFP32(b.D)
  546. dmin := FP16ToFP32(b.DMin)
  547. var p Q5KDotParams
  548. seg := 0
  549. is := 0
  550. for j := 0; j < QK_K; j += 64 {
  551. sc, m := getScaleMinK4(is+0, &b.Scales)
  552. p.D1[seg] = d * float32(sc)
  553. p.M1[seg] = dmin * float32(m)
  554. sc, m = getScaleMinK4(is+1, &b.Scales)
  555. p.D2[seg] = d * float32(sc)
  556. p.M2[seg] = dmin * float32(m)
  557. seg++
  558. is += 2
  559. }
  560. params[bi] = p
  561. }
  562. q5kDotParamsCache.Store(key, params)
  563. return params
  564. }
  565. func DotQ5_K_Params(b *BlockQ5_K, p *Q5KDotParams, x []float32) float32 {
  566. if len(x) != QK_K {
  567. panic("DotQ5_K_Params: mismatched slice length")
  568. }
  569. var sum float32
  570. qsPtr := 0
  571. qh := b.QH[:]
  572. for seg := 0; seg < 4; seg++ {
  573. d1, m1 := p.D1[seg], p.M1[seg]
  574. d2, m2 := p.D2[seg], p.M2[seg]
  575. outPtr := seg * 64
  576. shift1 := uint(2 * seg)
  577. shift2 := uint(2*seg + 1)
  578. if hasAVX512 {
  579. sum += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], &x[outPtr], d1, m1, d2, m2, shift1, shift2)
  580. qsPtr += 32
  581. continue
  582. }
  583. if hasAVX2 {
  584. sum += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], &x[outPtr], d1, m1, d2, m2, shift1, shift2)
  585. qsPtr += 32
  586. continue
  587. }
  588. u1 := uint8(1 << shift1)
  589. u2 := uint8(1 << shift2)
  590. for l := 0; l < 32; l++ {
  591. v := int(b.QS[qsPtr+l] & 0xF)
  592. if (qh[l] & u1) != 0 {
  593. v += 16
  594. }
  595. sum += x[outPtr+l] * (d1*float32(v) - m1)
  596. }
  597. for l := 0; l < 32; l++ {
  598. v := int(b.QS[qsPtr+l] >> 4)
  599. if (qh[l] & u2) != 0 {
  600. v += 16
  601. }
  602. sum += x[outPtr+32+l] * (d2*float32(v) - m2)
  603. }
  604. qsPtr += 32
  605. }
  606. return sum
  607. }
  608. func DotQ5_K_ParamsPtr(b *BlockQ5_K, p *Q5KDotParams, x *float32) float32 {
  609. var sum float32
  610. xp := unsafe.Pointer(x)
  611. qsPtr := 0
  612. qh := b.QH[:]
  613. for seg := 0; seg < 4; seg++ {
  614. d1, m1 := p.D1[seg], p.M1[seg]
  615. d2, m2 := p.D2[seg], p.M2[seg]
  616. outPtr := seg * 64
  617. shift1 := uint(2 * seg)
  618. shift2 := uint(2*seg + 1)
  619. if hasAVX512 {
  620. xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
  621. sum += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
  622. qsPtr += 32
  623. continue
  624. }
  625. if hasAVX2 {
  626. xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
  627. sum += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
  628. qsPtr += 32
  629. continue
  630. }
  631. u1 := uint8(1 << shift1)
  632. u2 := uint8(1 << shift2)
  633. for l := 0; l < 32; l++ {
  634. v := int(b.QS[qsPtr+l] & 0xF)
  635. if (qh[l] & u1) != 0 {
  636. v += 16
  637. }
  638. x0 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+l)*4))
  639. sum += x0 * (d1*float32(v) - m1)
  640. }
  641. for l := 0; l < 32; l++ {
  642. v := int(b.QS[qsPtr+l] >> 4)
  643. if (qh[l] & u2) != 0 {
  644. v += 16
  645. }
  646. x1 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+32+l)*4))
  647. sum += x1 * (d2*float32(v) - m2)
  648. }
  649. qsPtr += 32
  650. }
  651. return sum
  652. }
  653. func DotQ5KTile8(sums *[8]float32, w []BlockQ5_K, wp []Q5KDotParams, base int, stride int, x *float32, n int) {
  654. if n <= 0 {
  655. return
  656. }
  657. xp := unsafe.Pointer(x)
  658. for seg := 0; seg < 4; seg++ {
  659. xSeg := (*float32)(unsafe.Add(xp, uintptr(seg*64)*4))
  660. qsPtr := seg * 32
  661. shift1 := uint(2 * seg)
  662. shift2 := uint(2*seg + 1)
  663. u1 := uint8(1 << shift1)
  664. u2 := uint8(1 << shift2)
  665. xsp := unsafe.Pointer(xSeg)
  666. for t := 0; t < n; t++ {
  667. idx := base + t*stride
  668. b := &w[idx]
  669. p := &wp[idx]
  670. d1, m1 := p.D1[seg], p.M1[seg]
  671. d2, m2 := p.D2[seg], p.M2[seg]
  672. if hasAVX512 {
  673. sums[t] += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
  674. continue
  675. }
  676. if hasAVX2 {
  677. sums[t] += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
  678. continue
  679. }
  680. qh := b.QH[:]
  681. for l := 0; l < 32; l++ {
  682. v0 := int(b.QS[qsPtr+l] & 0xF)
  683. if (qh[l] & u1) != 0 {
  684. v0 += 16
  685. }
  686. v1 := int(b.QS[qsPtr+l] >> 4)
  687. if (qh[l] & u2) != 0 {
  688. v1 += 16
  689. }
  690. x0 := *(*float32)(unsafe.Add(xsp, uintptr(l)*4))
  691. x1 := *(*float32)(unsafe.Add(xsp, uintptr(32+l)*4))
  692. sums[t] += x0*(d1*float32(v0)-m1) + x1*(d2*float32(v1)-m2)
  693. }
  694. }
  695. }
  696. }
  697. func DequantizeQ5_K(b *BlockQ5_K, out []float32) {
  698. d := FP16ToFP32(b.D)
  699. min := FP16ToFP32(b.DMin)
  700. outIdx := 0
  701. ql := b.QS[:]
  702. qh := b.QH[:]
  703. is := 0
  704. u1 := uint8(1)
  705. u2 := uint8(2)
  706. for j := 0; j < QK_K; j += 64 {
  707. sc, m := getScaleMinK4(is+0, &b.Scales)
  708. d1 := d * float32(sc)
  709. m1 := min * float32(m)
  710. sc, m = getScaleMinK4(is+1, &b.Scales)
  711. d2 := d * float32(sc)
  712. m2 := min * float32(m)
  713. for l := 0; l < 32; l++ {
  714. v := int(ql[l] & 0xF)
  715. if (qh[l] & u1) != 0 {
  716. v += 16
  717. }
  718. out[outIdx] = d1*float32(v) - m1
  719. outIdx++
  720. }
  721. for l := 0; l < 32; l++ {
  722. v := int(ql[l] >> 4)
  723. if (qh[l] & u2) != 0 {
  724. v += 16
  725. }
  726. out[outIdx] = d2*float32(v) - m2
  727. outIdx++
  728. }
  729. ql = ql[32:]
  730. is += 2
  731. u1 <<= 2
  732. u2 <<= 2
  733. }
  734. }
  735. func DotQ5_K(b *BlockQ5_K, x []float32) float32 {
  736. if len(x) != QK_K {
  737. panic("DotQ5_K: mismatched slice length")
  738. }
  739. d := FP16ToFP32(b.D)
  740. min := FP16ToFP32(b.DMin)
  741. ql := b.QS[:]
  742. qh := b.QH[:]
  743. is := 0
  744. u1 := uint8(1)
  745. u2 := uint8(2)
  746. var sum float32
  747. outIdx := 0
  748. for j := 0; j < QK_K; j += 64 {
  749. sc, m := getScaleMinK4(is+0, &b.Scales)
  750. d1 := d * float32(sc)
  751. m1 := min * float32(m)
  752. sc, m = getScaleMinK4(is+1, &b.Scales)
  753. d2 := d * float32(sc)
  754. m2 := min * float32(m)
  755. for l := 0; l < 32; l++ {
  756. v := int(ql[l] & 0xF)
  757. if (qh[l] & u1) != 0 {
  758. v += 16
  759. }
  760. sum += x[outIdx] * (d1*float32(v) - m1)
  761. outIdx++
  762. }
  763. for l := 0; l < 32; l++ {
  764. v := int(ql[l] >> 4)
  765. if (qh[l] & u2) != 0 {
  766. v += 16
  767. }
  768. sum += x[outIdx] * (d2*float32(v) - m2)
  769. outIdx++
  770. }
  771. ql = ql[32:]
  772. is += 2
  773. u1 <<= 2
  774. u2 <<= 2
  775. }
  776. return sum
  777. }
  778. type Q4KDotParams struct {
  779. D1 [4]float32
  780. M1 [4]float32
  781. D2 [4]float32
  782. M2 [4]float32
  783. }
  784. type q4kParamsKey struct {
  785. p unsafe.Pointer
  786. n int
  787. }
  788. var q4kDotParamsCache sync.Map
  789. func GetQ4KDotParams(blocks []BlockQ4_K) []Q4KDotParams {
  790. if len(blocks) == 0 {
  791. return nil
  792. }
  793. key := q4kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
  794. if v, ok := q4kDotParamsCache.Load(key); ok {
  795. return v.([]Q4KDotParams)
  796. }
  797. params := make([]Q4KDotParams, len(blocks))
  798. for bi := range blocks {
  799. b := &blocks[bi]
  800. d := FP16ToFP32(b.D)
  801. dmin := FP16ToFP32(b.DMin)
  802. var sc [8]uint8
  803. var m [8]uint8
  804. for j := 0; j < 4; j++ {
  805. sc[j] = b.Scales[j] & 63
  806. m[j] = b.Scales[j+4] & 63
  807. }
  808. for j := 4; j < 8; j++ {
  809. sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
  810. m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
  811. }
  812. var p Q4KDotParams
  813. seg := 0
  814. for i := 0; i < 8; i += 2 {
  815. p.D1[seg] = d * float32(sc[i])
  816. p.M1[seg] = dmin * float32(m[i])
  817. p.D2[seg] = d * float32(sc[i+1])
  818. p.M2[seg] = dmin * float32(m[i+1])
  819. seg++
  820. }
  821. params[bi] = p
  822. }
  823. q4kDotParamsCache.Store(key, params)
  824. return params
  825. }
  826. func DotQ4_K_Params(b *BlockQ4_K, p *Q4KDotParams, x []float32) float32 {
  827. if len(x) != QK_K {
  828. panic("DotQ4_K_Params: mismatched slice length")
  829. }
  830. var sum float32
  831. outPtr := 0
  832. qsPtr := 0
  833. for seg := 0; seg < 4; seg++ {
  834. d1, m1 := p.D1[seg], p.M1[seg]
  835. d2, m2 := p.D2[seg], p.M2[seg]
  836. if hasAVX512 {
  837. sum += dotQ4KInnerAVX512(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
  838. } else if hasAVX2 {
  839. sum += dotQ4KInnerAVX2(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
  840. } else {
  841. for l := 0; l < 32; l++ {
  842. val := b.QS[qsPtr+l]
  843. v1 := val & 0xF
  844. v2 := val >> 4
  845. sum += x[outPtr] * (float32(v1)*d1 - m1)
  846. sum += x[outPtr+32] * (float32(v2)*d2 - m2)
  847. outPtr++
  848. }
  849. outPtr += 32
  850. qsPtr += 32
  851. continue
  852. }
  853. outPtr += 64
  854. qsPtr += 32
  855. }
  856. return sum
  857. }
  858. func DotQ4_K_ParamsPtr(b *BlockQ4_K, p *Q4KDotParams, x *float32) float32 {
  859. var sum float32
  860. xp := unsafe.Pointer(x)
  861. outPtr := 0
  862. qsPtr := 0
  863. for seg := 0; seg < 4; seg++ {
  864. d1, m1 := p.D1[seg], p.M1[seg]
  865. d2, m2 := p.D2[seg], p.M2[seg]
  866. xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
  867. if hasAVX512 {
  868. sum += dotQ4KInnerAVX512(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
  869. } else if hasAVX2 {
  870. sum += dotQ4KInnerAVX2(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
  871. } else {
  872. for l := 0; l < 32; l++ {
  873. val := b.QS[qsPtr+l]
  874. v1 := val & 0xF
  875. v2 := val >> 4
  876. x0 := *(*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
  877. x1 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+32)*4))
  878. sum += x0 * (float32(v1)*d1 - m1)
  879. sum += x1 * (float32(v2)*d2 - m2)
  880. outPtr++
  881. }
  882. outPtr += 32
  883. qsPtr += 32
  884. continue
  885. }
  886. outPtr += 64
  887. qsPtr += 32
  888. }
  889. return sum
  890. }
  891. func DotQ4KTile8(sums *[8]float32, w []BlockQ4_K, wp []Q4KDotParams, base int, stride int, x *float32, n int) {
  892. if n <= 0 {
  893. return
  894. }
  895. xp := unsafe.Pointer(x)
  896. outPtr := 0
  897. qsPtr := 0
  898. for seg := 0; seg < 4; seg++ {
  899. xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
  900. for t := 0; t < n; t++ {
  901. idx := base + t*stride
  902. b := &w[idx]
  903. p := &wp[idx]
  904. d1, m1 := p.D1[seg], p.M1[seg]
  905. d2, m2 := p.D2[seg], p.M2[seg]
  906. if hasAVX512 {
  907. sums[t] += dotQ4KInnerAVX512(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
  908. } else if hasAVX2 {
  909. sums[t] += dotQ4KInnerAVX2(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
  910. } else {
  911. xsp := unsafe.Pointer(xSeg)
  912. for l := 0; l < 32; l++ {
  913. val := b.QS[qsPtr+l]
  914. v1 := val & 0xF
  915. v2 := val >> 4
  916. x0 := *(*float32)(unsafe.Add(xsp, uintptr(l)*4))
  917. x1 := *(*float32)(unsafe.Add(xsp, uintptr(32+l)*4))
  918. sums[t] += x0*(float32(v1)*d1-m1) + x1*(float32(v2)*d2-m2)
  919. }
  920. }
  921. }
  922. outPtr += 64
  923. qsPtr += 32
  924. }
  925. }
  926. // BlockQ8_K represents a block of 256 weights quantized to 8 bits
  927. // Layout (292 bytes):
  928. // - D (4 bytes): float32 scale
  929. // - QS (256 bytes): 256 int8 quants
  930. // - BSums (32 bytes): 16 int16 block sums (for dot product optimization, not used in dequant)
  931. type BlockQ8_K struct {
  932. D float32
  933. QS [256]int8
  934. BSums [16]int16
  935. }
  936. // DequantizeQ8_K dequantizes a single Q8_K block into 256 floats
  937. func DequantizeQ8_K(b *BlockQ8_K, out []float32) {
  938. if dequantQ8KSimd(b, out) {
  939. return
  940. }
  941. d := b.D
  942. for i := 0; i < 256; i++ {
  943. out[i] = d * float32(b.QS[i])
  944. }
  945. }
  946. func DotQ8_K(b *BlockQ8_K, x []float32) float32 {
  947. if len(x) != QK_K {
  948. panic("DotQ8_K: mismatched slice length")
  949. }
  950. if sum, ok := dotQ8KSimd(b, x); ok {
  951. return sum
  952. }
  953. d := b.D
  954. var sum float32
  955. for i := 0; i < 256; i++ {
  956. sum += x[i] * float32(b.QS[i])
  957. }
  958. return d * sum
  959. }
// BlockQ2_K represents a block of 256 weights quantized to 2 bits.
// Layout (84 bytes):
// - Scales (16 bytes): 16 4-bit scales and 16 4-bit mins packed
// - QS (64 bytes): 256 2-bit quants packed
// - D (2 bytes): float16 super-scale
// - DMin (2 bytes): float16 super-min-scale
//
// Field order is part of the documented layout; do not reorder.
type BlockQ2_K struct {
	Scales [16]uint8 // one byte per 16 weights: low nibble scales D, high nibble scales DMin
	QS     [64]uint8 // four 2-bit quants per byte, selected with shifts of 0, 2, 4 and 6
	D      uint16    // float16 bit pattern, decoded with FP16ToFP32
	DMin   uint16    // float16 bit pattern, decoded with FP16ToFP32
}
// Q2KDotParams holds the per-group scale and min of one BlockQ2_K as float32
// values, as filled in by GetQ2KDotParams.
type Q2KDotParams struct {
	DL [16]float32 // DL[i] = FP16ToFP32(D) * float32(Scales[i] & 0xF)
	ML [16]float32 // ML[i] = FP16ToFP32(DMin) * float32(Scales[i] >> 4)
}
  976. type q2kParamsKey struct {
  977. p unsafe.Pointer
  978. n int
  979. }
  980. var q2kDotParamsCache sync.Map
  981. func GetQ2KDotParams(blocks []BlockQ2_K) []Q2KDotParams {
  982. if len(blocks) == 0 {
  983. return nil
  984. }
  985. key := q2kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
  986. if v, ok := q2kDotParamsCache.Load(key); ok {
  987. return v.([]Q2KDotParams)
  988. }
  989. params := make([]Q2KDotParams, len(blocks))
  990. for bi := range blocks {
  991. b := &blocks[bi]
  992. d := FP16ToFP32(b.D)
  993. dmin := FP16ToFP32(b.DMin)
  994. var p Q2KDotParams
  995. pi := 0
  996. is := 0
  997. for n := 0; n < QK_K; n += 128 {
  998. _ = n
  999. for shift := uint(0); shift < 8; shift += 2 {
  1000. sc := b.Scales[is]
  1001. is++
  1002. p.DL[pi] = d * float32(sc&0xF)
  1003. p.ML[pi] = dmin * float32(sc>>4)
  1004. pi++
  1005. sc = b.Scales[is]
  1006. is++
  1007. p.DL[pi] = d * float32(sc&0xF)
  1008. p.ML[pi] = dmin * float32(sc>>4)
  1009. pi++
  1010. }
  1011. }
  1012. params[bi] = p
  1013. }
  1014. q2kDotParamsCache.Store(key, params)
  1015. return params
  1016. }
  1017. // DequantizeQ2_K dequantizes a single Q2_K block into 256 floats
  1018. // Mirrors llama.cpp dequantize_row_q2_K
  1019. func DequantizeQ2_K(b *BlockQ2_K, out []float32) {
  1020. if dequantQ2KSimd(b, out) {
  1021. return
  1022. }
  1023. d := FP16ToFP32(b.D)
  1024. dmin := FP16ToFP32(b.DMin)
  1025. q := b.QS[:]
  1026. is := 0
  1027. outIdx := 0
  1028. for n := 0; n < QK_K; n += 128 {
  1029. shift := uint(0)
  1030. for j := 0; j < 4; j++ {
  1031. sc := b.Scales[is]
  1032. is++
  1033. dl := d * float32(sc&0xF)
  1034. ml := dmin * float32(sc>>4)
  1035. for l := 0; l < 16; l++ {
  1036. val := int8((q[l] >> shift) & 3)
  1037. out[outIdx] = dl*float32(val) - ml
  1038. outIdx++
  1039. }
  1040. sc = b.Scales[is]
  1041. is++
  1042. dl = d * float32(sc&0xF)
  1043. ml = dmin * float32(sc>>4)
  1044. for l := 0; l < 16; l++ {
  1045. val := int8((q[l+16] >> shift) & 3)
  1046. out[outIdx] = dl*float32(val) - ml
  1047. outIdx++
  1048. }
  1049. shift += 2
  1050. }
  1051. q = q[32:]
  1052. }
  1053. }
  1054. func DotQ2_K(b *BlockQ2_K, x []float32) float32 {
  1055. if len(x) != QK_K {
  1056. panic("DotQ2_K: mismatched slice length")
  1057. }
  1058. if sum, ok := dotQ2KSimd(b, x); ok {
  1059. return sum
  1060. }
  1061. d := FP16ToFP32(b.D)
  1062. dmin := FP16ToFP32(b.DMin)
  1063. q := b.QS[:]
  1064. is := 0
  1065. outIdx := 0
  1066. var sum float32
  1067. for n := 0; n < QK_K; n += 128 {
  1068. shift := uint(0)
  1069. for j := 0; j < 4; j++ {
  1070. sc := b.Scales[is]
  1071. is++
  1072. dl := d * float32(sc&0xF)
  1073. ml := dmin * float32(sc>>4)
  1074. for l := 0; l < 16; l++ {
  1075. val := float32(int8((q[l] >> shift) & 3))
  1076. sum += x[outIdx] * (dl*val - ml)
  1077. outIdx++
  1078. }
  1079. sc = b.Scales[is]
  1080. is++
  1081. dl = d * float32(sc&0xF)
  1082. ml = dmin * float32(sc>>4)
  1083. for l := 0; l < 16; l++ {
  1084. val := float32(int8((q[l+16] >> shift) & 3))
  1085. sum += x[outIdx] * (dl*val - ml)
  1086. outIdx++
  1087. }
  1088. shift += 2
  1089. }
  1090. q = q[32:]
  1091. }
  1092. return sum
  1093. }
// BlockQ3_K represents a block of 256 weights quantized to 3 bits.
// Layout (110 bytes):
// - HMask (32 bytes): high bit of 3-bit quants
// - QS (64 bytes): low 2 bits of 3-bit quants
// - Scales (12 bytes): 6-bit scales packed
// - D (2 bytes): float16 super-scale
//
// Field order is part of the documented layout; do not reorder.
type BlockQ3_K struct {
	HMask  [32]uint8 // one bit per weight; a clear bit subtracts 4 from the 2-bit value taken from QS
	QS     [64]uint8 // four 2-bit values per byte, selected with shifts of 0, 2, 4 and 6
	Scales [12]uint8 // packed scales, decoded one at a time with unpackQ3Scale
	D      uint16    // float16 bit pattern, decoded with FP16ToFP32
}
// Q3KDotParams holds the 16 group scales of one BlockQ3_K as float32 values,
// as filled in by GetQ3KDotParams.
type Q3KDotParams struct {
	S [16]float32 // S[i] = FP16ToFP32(D) * float32(unpackQ3Scale(Scales, i))
}
  1109. type q3kParamsKey struct {
  1110. p unsafe.Pointer
  1111. n int
  1112. }
  1113. var q3kDotParamsCache sync.Map
  1114. func GetQ3KDotParams(blocks []BlockQ3_K) []Q3KDotParams {
  1115. if len(blocks) == 0 {
  1116. return nil
  1117. }
  1118. key := q3kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
  1119. if v, ok := q3kDotParamsCache.Load(key); ok {
  1120. return v.([]Q3KDotParams)
  1121. }
  1122. params := make([]Q3KDotParams, len(blocks))
  1123. for bi := range blocks {
  1124. b := &blocks[bi]
  1125. d := FP16ToFP32(b.D)
  1126. var p Q3KDotParams
  1127. for i := 0; i < 16; i++ {
  1128. p.S[i] = d * float32(unpackQ3Scale(b.Scales[:], i))
  1129. }
  1130. params[bi] = p
  1131. }
  1132. q3kDotParamsCache.Store(key, params)
  1133. return params
  1134. }
  1135. // DequantizeQ3_K dequantizes a single Q3_K block into 256 floats
  1136. func DequantizeQ3_K(b *BlockQ3_K, out []float32) {
  1137. if dequantQ3KSimd(b, out) {
  1138. return
  1139. }
  1140. d := FP16ToFP32(b.D)
  1141. var scales [16]int8
  1142. for i := 0; i < 16; i++ {
  1143. scales[i] = unpackQ3Scale(b.Scales[:], i)
  1144. }
  1145. q := b.QS[:]
  1146. hm := b.HMask[:]
  1147. outIdx := 0
  1148. is := 0
  1149. m := uint8(1)
  1150. for n := 0; n < QK_K; n += 128 {
  1151. shift := uint(0)
  1152. for j := 0; j < 4; j++ {
  1153. dl := d * float32(scales[is])
  1154. is++
  1155. for l := 0; l < 16; l++ {
  1156. qv := int8((q[l] >> shift) & 0x3)
  1157. if hm[l]&m == 0 {
  1158. qv -= 4
  1159. }
  1160. out[outIdx] = dl * float32(qv)
  1161. outIdx++
  1162. }
  1163. dl = d * float32(scales[is])
  1164. is++
  1165. for l := 0; l < 16; l++ {
  1166. qv := int8((q[l+16] >> shift) & 0x3)
  1167. if hm[l+16]&m == 0 {
  1168. qv -= 4
  1169. }
  1170. out[outIdx] = dl * float32(qv)
  1171. outIdx++
  1172. }
  1173. shift += 2
  1174. m <<= 1
  1175. }
  1176. q = q[32:]
  1177. }
  1178. }
  1179. func DotQ3_K(b *BlockQ3_K, x []float32) float32 {
  1180. if len(x) != QK_K {
  1181. panic("DotQ3_K: mismatched slice length")
  1182. }
  1183. if sum, ok := dotQ3KSimd(b, x); ok {
  1184. return sum
  1185. }
  1186. d := FP16ToFP32(b.D)
  1187. var scales [16]int8
  1188. for i := 0; i < 16; i++ {
  1189. scales[i] = unpackQ3Scale(b.Scales[:], i)
  1190. }
  1191. q := b.QS[:]
  1192. hm := b.HMask[:]
  1193. outIdx := 0
  1194. is := 0
  1195. m := uint8(1)
  1196. var sum float32
  1197. for n := 0; n < QK_K; n += 128 {
  1198. shift := uint(0)
  1199. for j := 0; j < 4; j++ {
  1200. dl := d * float32(scales[is])
  1201. is++
  1202. for l := 0; l < 16; l++ {
  1203. qv := int8((q[l] >> shift) & 0x3)
  1204. if hm[l]&m == 0 {
  1205. qv -= 4
  1206. }
  1207. sum += x[outIdx] * (dl * float32(qv))
  1208. outIdx++
  1209. }
  1210. dl = d * float32(scales[is])
  1211. is++
  1212. for l := 0; l < 16; l++ {
  1213. qv := int8((q[l+16] >> shift) & 0x3)
  1214. if hm[l+16]&m == 0 {
  1215. qv -= 4
  1216. }
  1217. sum += x[outIdx] * (dl * float32(qv))
  1218. outIdx++
  1219. }
  1220. shift += 2
  1221. m <<= 1
  1222. }
  1223. q = q[32:]
  1224. }
  1225. return sum
  1226. }
// BlockQ6_K represents a block of 256 weights quantized to 6 bits.
// Layout (210 bytes):
// - QL (128 bytes): lower 4 bits of 6-bit quants
// - QH (64 bytes): upper 2 bits of 6-bit quants
// - Scales (16 bytes): 8-bit signed scales
// - D (2 bytes): float16 super-scale
//
// Field order is part of the documented layout; do not reorder.
type BlockQ6_K struct {
	QL     [128]uint8 // two 4-bit low parts per byte (low nibble and high nibble)
	QH     [64]uint8  // four 2-bit high parts per byte, selected with shifts of 0, 2, 4 and 6
	Scales [16]int8   // one signed scale per 16 weights, multiplied by the decoded D
	D      uint16     // float16 bit pattern, decoded with FP16ToFP32
}
// Q6KDotParams holds the 16 group scales of one BlockQ6_K as float32 values,
// as filled in by GetQ6KDotParams.
type Q6KDotParams struct {
	S [16]float32 // S[i] = FP16ToFP32(D) * float32(Scales[i])
}
  1242. type q6kParamsKey struct {
  1243. p unsafe.Pointer
  1244. n int
  1245. }
  1246. var q6kDotParamsCache sync.Map
  1247. func GetQ6KDotParams(blocks []BlockQ6_K) []Q6KDotParams {
  1248. if len(blocks) == 0 {
  1249. return nil
  1250. }
  1251. key := q6kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
  1252. if v, ok := q6kDotParamsCache.Load(key); ok {
  1253. return v.([]Q6KDotParams)
  1254. }
  1255. params := make([]Q6KDotParams, len(blocks))
  1256. for bi := range blocks {
  1257. b := &blocks[bi]
  1258. d := FP16ToFP32(b.D)
  1259. var p Q6KDotParams
  1260. for i := 0; i < 16; i++ {
  1261. p.S[i] = d * float32(b.Scales[i])
  1262. }
  1263. params[bi] = p
  1264. }
  1265. q6kDotParamsCache.Store(key, params)
  1266. return params
  1267. }
  1268. // DequantizeQ6_K dequantizes a single Q6_K block into 256 floats
  1269. // Logic adapted from llama.cpp's dequantize_row_q6_K
  1270. func DequantizeQ6_K(b *BlockQ6_K, out []float32) {
  1271. if dequantQ6KSimd(b, out) {
  1272. return
  1273. }
  1274. d := FP16ToFP32(b.D)
  1275. qlPtr := 0
  1276. qhPtr := 0
  1277. scPtr := 0
  1278. outPtr := 0
  1279. for n := 0; n < 256; n += 128 {
  1280. for l := 0; l < 32; l++ {
  1281. is := l / 16
  1282. ql0 := b.QL[qlPtr+l]
  1283. ql32 := b.QL[qlPtr+l+32]
  1284. qh := b.QH[qhPtr+l]
  1285. q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
  1286. q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
  1287. q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
  1288. q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
  1289. out[outPtr+l+0] = d * float32(b.Scales[scPtr+is+0]) * float32(q1)
  1290. out[outPtr+l+32] = d * float32(b.Scales[scPtr+is+2]) * float32(q2)
  1291. out[outPtr+l+64] = d * float32(b.Scales[scPtr+is+4]) * float32(q3)
  1292. out[outPtr+l+96] = d * float32(b.Scales[scPtr+is+6]) * float32(q4)
  1293. }
  1294. outPtr += 128
  1295. qlPtr += 64
  1296. qhPtr += 32
  1297. scPtr += 8
  1298. }
  1299. }
  1300. func DotQ6_K(b *BlockQ6_K, x []float32) float32 {
  1301. if len(x) != QK_K {
  1302. panic("DotQ6_K: mismatched slice length")
  1303. }
  1304. if sum, ok := dotQ6KSimd(b, x); ok {
  1305. return sum
  1306. }
  1307. d := FP16ToFP32(b.D)
  1308. qlPtr := 0
  1309. qhPtr := 0
  1310. scPtr := 0
  1311. outPtr := 0
  1312. var sum float32
  1313. for n := 0; n < 256; n += 128 {
  1314. for l := 0; l < 32; l++ {
  1315. is := l / 16
  1316. ql0 := b.QL[qlPtr+l]
  1317. ql32 := b.QL[qlPtr+l+32]
  1318. qh := b.QH[qhPtr+l]
  1319. q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
  1320. q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
  1321. q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
  1322. q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
  1323. s0 := d * float32(b.Scales[scPtr+is+0])
  1324. s2 := d * float32(b.Scales[scPtr+is+2])
  1325. s4 := d * float32(b.Scales[scPtr+is+4])
  1326. s6 := d * float32(b.Scales[scPtr+is+6])
  1327. base := outPtr + l
  1328. sum += x[base+0] * s0 * float32(q1)
  1329. sum += x[base+32] * s2 * float32(q2)
  1330. sum += x[base+64] * s4 * float32(q3)
  1331. sum += x[base+96] * s6 * float32(q4)
  1332. }
  1333. outPtr += 128
  1334. qlPtr += 64
  1335. qhPtr += 32
  1336. scPtr += 8
  1337. }
  1338. return sum
  1339. }