quantize.go 24 KB


  1. // Package quant provides fast quantization functions for model weights
  2. package quant
  3. import (
  4. "math"
  5. )
// QK_K is the number of values in one K-quant super-block. Each Quantize*
// function in this file zero-pads its input to a multiple of QK_K and emits
// one fixed-size encoded block per QK_K values.
const QK_K = 256
  7. // QuantizeQ8K quantizes float32 data to Q8_K format
  8. // Block layout (292 bytes per 256 elements):
  9. // - D (4 bytes): float32 scale
  10. // - QS (256 bytes): 256 int8 quants
  11. // - BSums (32 bytes): 16 int16 block sums
  12. func QuantizeQ8K(data []float32) []byte {
  13. // Pad to multiple of QK_K
  14. n := len(data)
  15. padding := (QK_K - (n % QK_K)) % QK_K
  16. if padding > 0 {
  17. padded := make([]float32, n+padding)
  18. copy(padded, data)
  19. data = padded
  20. }
  21. nBlocks := len(data) / QK_K
  22. // 292 bytes per block: 4 (d) + 256 (qs) + 32 (bsums)
  23. out := make([]byte, nBlocks*292)
  24. for b := 0; b < nBlocks; b++ {
  25. block := data[b*QK_K : (b+1)*QK_K]
  26. outBlock := out[b*292 : (b+1)*292]
  27. // Find max absolute value
  28. var amax float32
  29. for _, v := range block {
  30. if abs := float32(math.Abs(float64(v))); abs > amax {
  31. amax = abs
  32. }
  33. }
  34. // Calculate scale
  35. d := amax / 127.0
  36. var iscale float32
  37. if amax > 0 {
  38. iscale = 127.0 / amax
  39. }
  40. // Write d as float32 (little endian)
  41. dBits := math.Float32bits(d)
  42. outBlock[0] = byte(dBits)
  43. outBlock[1] = byte(dBits >> 8)
  44. outBlock[2] = byte(dBits >> 16)
  45. outBlock[3] = byte(dBits >> 24)
  46. // Quantize and write QS, calculate bsums
  47. var bsums [16]int16
  48. for i := 0; i < QK_K; i++ {
  49. q := int8(clampInt(int(math.Round(float64(block[i]*iscale))), -127, 127))
  50. outBlock[4+i] = byte(q)
  51. bsums[i/16] += int16(q)
  52. }
  53. // Write bsums (16 int16, little endian)
  54. for i := 0; i < 16; i++ {
  55. outBlock[260+i*2] = byte(bsums[i])
  56. outBlock[260+i*2+1] = byte(bsums[i] >> 8)
  57. }
  58. }
  59. return out
  60. }
  61. func getScaleMinK4(j int, q *[12]uint8) (d uint8, m uint8) {
  62. if j < 4 {
  63. d = q[j] & 63
  64. m = q[j+4] & 63
  65. return d, m
  66. }
  67. d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4)
  68. m = (q[j+4] >> 4) | ((q[j] >> 6) << 4)
  69. return d, m
  70. }
// makeQKX2Quants32 fits an affine 5-bit quantization to one 32-value
// sub-block (used by QuantizeQ5K): x[i] ≈ scale*L[i] + min, with min <= 0
// and L[i] in [0, 31].
//
// x supplies the values and weights the per-value importance; only the first
// 32 entries of each are read (shorter slices panic). Errors are weighted
// squared differences (useMAD is false, so the MAD branches are dead).
//
// The search first quantizes on the plain [min, max] grid, then tries
// nstep+1 = 16 perturbed inverse scales (nmax + rmin + rdelta*k over the
// range). For each candidate grid it refits scale and min by weighted least
// squares — or scale alone with min = 0 when the fitted min is positive —
// and keeps the candidate with the lowest weighted error.
//
// Returns the best scale, its quants, and -min (>= 0). If all values are
// equal once min is clamped to <= 0, it returns scale 0, all-zero quants and
// -min. makeQKX2Quants runs the same search for Q2_K (16 values, nmax 3,
// |x| weights, absolute error).
//
// NOTE(review): minVal (the grid origin) is never updated during the search,
// even after a better min is found; confirm this matches the reference
// make_qkx2_quants if bit-exact parity with llama.cpp matters.
func makeQKX2Quants32(x []float32, weights []float32) (float32, [32]uint8, float32) {
	const n = 32
	const nmax = 31.0
	const rmin = -0.5
	const rdelta = 0.1
	const nstep = 15
	const useMAD = false
	var lBest [32]uint8
	// Range and weighted sums over the sub-block.
	minVal := x[0]
	maxVal := x[0]
	sumW := weights[0]
	sumX := sumW * x[0]
	for i := 1; i < n; i++ {
		v := x[i]
		if v < minVal {
			minVal = v
		}
		if v > maxVal {
			maxVal = v
		}
		w := weights[i]
		sumW += w
		sumX += w * v
	}
	// The offset is never positive, so the returned -min stays >= 0.
	if minVal > 0 {
		minVal = 0
	}
	// Degenerate range: every value decodes to min with scale 0.
	if maxVal == minVal {
		return 0, lBest, -minVal
	}
	// Baseline candidate: plain min/max grid with nmax steps.
	iscale := float32(nmax) / (maxVal - minVal)
	scale := 1 / iscale
	bestErr := float32(0)
	var L [32]uint8
	for i := 0; i < n; i++ {
		l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 31)
		L[i] = uint8(l)
		diff := scale*float32(l) + minVal - x[i]
		if useMAD {
			if diff < 0 {
				diff = -diff
			}
		} else {
			diff = diff * diff
		}
		bestErr += weights[i] * diff
	}
	bestScale := scale
	bestMin := minVal
	copy(lBest[:], L[:])
	var Laux [32]uint8
	for isIdx := 0; isIdx <= nstep; isIdx++ {
		// Perturbed inverse scale for this candidate grid.
		iscale = (float32(rmin) + float32(rdelta)*float32(isIdx) + float32(nmax)) / (maxVal - minVal)
		var sumL, sumL2, sumXL float32
		for i := 0; i < n; i++ {
			l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 31)
			Laux[i] = uint8(l)
			lf := float32(l)
			w := weights[i]
			sumL += w * lf
			sumL2 += w * lf * lf
			sumXL += w * lf * x[i]
		}
		// Weighted least squares for x ≈ thisScale*L + thisMin; D is the
		// determinant of the normal equations.
		D := sumW*sumL2 - sumL*sumL
		if D > 0 {
			thisScale := (sumW*sumXL - sumX*sumL) / D
			thisMin := (sumL2*sumX - sumL*sumXL) / D
			// Positive offsets are not allowed: refit the scale with min = 0.
			if thisMin > 0 {
				thisMin = 0
				thisScale = sumXL / sumL2
			}
			curErr := float32(0)
			for i := 0; i < n; i++ {
				diff := thisScale*float32(Laux[i]) + thisMin - x[i]
				if useMAD {
					if diff < 0 {
						diff = -diff
					}
				} else {
					diff = diff * diff
				}
				curErr += weights[i] * diff
			}
			if curErr < bestErr {
				copy(lBest[:], Laux[:])
				bestErr = curErr
				bestScale = thisScale
				bestMin = thisMin
			}
		}
	}
	return bestScale, lBest, -bestMin
}
  164. func QuantizeQ5K(data []float32) []byte {
  165. n := len(data)
  166. padding := (QK_K - (n % QK_K)) % QK_K
  167. if padding > 0 {
  168. padded := make([]float32, n+padding)
  169. copy(padded, data)
  170. data = padded
  171. }
  172. nBlocks := len(data) / QK_K
  173. out := make([]byte, nBlocks*176)
  174. for b := 0; b < nBlocks; b++ {
  175. block := data[b*QK_K : (b+1)*QK_K]
  176. outBlock := out[b*176 : (b+1)*176]
  177. var L [QK_K]uint8
  178. var mins [8]float32
  179. var scales [8]float32
  180. maxScale := float32(0)
  181. maxMin := float32(0)
  182. var weights [32]float32
  183. for j := 0; j < 8; j++ {
  184. seg := block[j*32 : (j+1)*32]
  185. var sumX2 float32
  186. for l := 0; l < 32; l++ {
  187. v := seg[l]
  188. sumX2 += v * v
  189. }
  190. avX := float32(math.Sqrt(float64(sumX2 / 32)))
  191. for l := 0; l < 32; l++ {
  192. v := seg[l]
  193. absV := float32(math.Abs(float64(v)))
  194. weights[l] = avX + absV
  195. }
  196. sc, lq, mn := makeQKX2Quants32(seg, weights[:])
  197. scales[j] = sc
  198. mins[j] = mn
  199. copy(L[j*32:(j+1)*32], lq[:])
  200. if sc > maxScale {
  201. maxScale = sc
  202. }
  203. if mn > maxMin {
  204. maxMin = mn
  205. }
  206. }
  207. invScale := float32(0)
  208. invMin := float32(0)
  209. if maxScale > 0 {
  210. invScale = 63.0 / maxScale
  211. }
  212. if maxMin > 0 {
  213. invMin = 63.0 / maxMin
  214. }
  215. var scalesPacked [12]uint8
  216. for j := 0; j < 8; j++ {
  217. ls := uint8(clampInt(nearestIntFloat32(invScale*scales[j]), 0, 63))
  218. lm := uint8(clampInt(nearestIntFloat32(invMin*mins[j]), 0, 63))
  219. if j < 4 {
  220. scalesPacked[j] = ls
  221. scalesPacked[j+4] = lm
  222. } else {
  223. scalesPacked[j+4] = (ls & 0xF) | ((lm & 0xF) << 4)
  224. scalesPacked[j-4] |= ((ls >> 4) << 6)
  225. scalesPacked[j] |= ((lm >> 4) << 6)
  226. }
  227. }
  228. copy(outBlock[4:16], scalesPacked[:])
  229. dVal := maxScale / 63.0
  230. dMinVal := maxMin / 63.0
  231. dF16 := float32ToFloat16(dVal)
  232. dMinF16 := float32ToFloat16(dMinVal)
  233. outBlock[0] = byte(dF16)
  234. outBlock[1] = byte(dF16 >> 8)
  235. outBlock[2] = byte(dMinF16)
  236. outBlock[3] = byte(dMinF16 >> 8)
  237. if maxScale > 0 {
  238. for j := 0; j < 8; j++ {
  239. sc, m := getScaleMinK4(j, &scalesPacked)
  240. dLocal := dVal * float32(sc)
  241. if dLocal == 0 {
  242. continue
  243. }
  244. dm := dMinVal * float32(m)
  245. for ii := 0; ii < 32; ii++ {
  246. l := nearestIntFloat32((block[j*32+ii] + dm) / dLocal)
  247. L[j*32+ii] = uint8(clampInt(l, 0, 31))
  248. }
  249. }
  250. }
  251. qh := outBlock[16:48]
  252. qs := outBlock[48:176]
  253. for i := range qh {
  254. qh[i] = 0
  255. }
  256. m1 := uint8(1)
  257. m2 := uint8(2)
  258. qsOff := 0
  259. for n0 := 0; n0 < QK_K; n0 += 64 {
  260. for j := 0; j < 32; j++ {
  261. l1 := L[n0+j]
  262. if l1 > 15 {
  263. l1 -= 16
  264. qh[j] |= m1
  265. }
  266. l2 := L[n0+j+32]
  267. if l2 > 15 {
  268. l2 -= 16
  269. qh[j] |= m2
  270. }
  271. qs[qsOff+j] = l1 | (l2 << 4)
  272. }
  273. m1 <<= 2
  274. m2 <<= 2
  275. qsOff += 32
  276. }
  277. }
  278. return out
  279. }
  280. // QuantizeQ6K quantizes float32 data to Q6_K format
  281. // Layout (210 bytes per 256 elements):
  282. // - QL (128 bytes): lower 4 bits of 6-bit quants
  283. // - QH (64 bytes): upper 2 bits of 6-bit quants
  284. // - Scales (16 bytes): 8-bit signed scales
  285. // - D (2 bytes): float16 super-scale
  286. func QuantizeQ6K(data []float32) []byte {
  287. n := len(data)
  288. padding := (QK_K - (n % QK_K)) % QK_K
  289. if padding > 0 {
  290. padded := make([]float32, n+padding)
  291. copy(padded, data)
  292. data = padded
  293. }
  294. nBlocks := len(data) / QK_K
  295. out := make([]byte, nBlocks*210)
  296. for b := 0; b < nBlocks; b++ {
  297. block := data[b*QK_K : (b+1)*QK_K]
  298. outBlock := out[b*210 : (b+1)*210]
  299. // Calculate scales per 16-element sub-block (16 sub-blocks)
  300. var sbScale [16]float32
  301. var maxScale float32
  302. for j := 0; j < 16; j++ {
  303. sub := block[j*16 : (j+1)*16]
  304. var sbMax float32
  305. for _, v := range sub {
  306. if abs := float32(math.Abs(float64(v))); abs > sbMax {
  307. sbMax = abs
  308. }
  309. }
  310. if sbMax == 0 {
  311. sbMax = 1.0
  312. }
  313. sbScale[j] = sbMax / 31.5
  314. if sbScale[j] == 0 {
  315. sbScale[j] = 1.0
  316. }
  317. if sbScale[j] > maxScale {
  318. maxScale = sbScale[j]
  319. }
  320. }
  321. // Super-block scale
  322. dVal := maxScale / 127.0
  323. if dVal == 0 {
  324. dVal = 1.0
  325. }
  326. // Quantize sub-scales to 8-bit signed
  327. var ls [16]int8
  328. for j := 0; j < 16; j++ {
  329. ls[j] = int8(clampInt(int(math.Round(float64(sbScale[j]/dVal))), -128, 127))
  330. }
  331. // Restore dVal zeros
  332. if maxScale == 0 {
  333. dVal = 0
  334. }
  335. // Reconstruct scales and quantize weights
  336. var qVals [256]uint8
  337. for j := 0; j < 16; j++ {
  338. recS := float32(ls[j]) * dVal
  339. if recS == 0 {
  340. recS = 1.0
  341. }
  342. for i := 0; i < 16; i++ {
  343. q := int(math.Round(float64(block[j*16+i] / recS)))
  344. q = clampInt(q, -32, 31)
  345. qVals[j*16+i] = uint8(q + 32) // [0, 63]
  346. }
  347. }
  348. // Pack QL and QH
  349. ql := outBlock[0:128]
  350. qh := outBlock[128:192]
  351. // Process 2 halves of 128 weights each
  352. for nIdx := 0; nIdx < 256; nIdx += 128 {
  353. qlBase := nIdx / 2 // 0 or 64
  354. qhBase := nIdx / 4 // 0 or 32
  355. for l := 0; l < 32; l++ {
  356. idx1 := nIdx + l
  357. idx2 := nIdx + l + 32
  358. idx3 := nIdx + l + 64
  359. idx4 := nIdx + l + 96
  360. q1 := qVals[idx1]
  361. q2 := qVals[idx2]
  362. q3 := qVals[idx3]
  363. q4 := qVals[idx4]
  364. // Pack QL
  365. ql[qlBase+l] = (q1 & 0xF) | ((q3 & 0xF) << 4)
  366. ql[qlBase+l+32] = (q2 & 0xF) | ((q4 & 0xF) << 4)
  367. // Pack QH
  368. valH := ((q1 >> 4) & 0x3) |
  369. (((q2 >> 4) & 0x3) << 2) |
  370. (((q3 >> 4) & 0x3) << 4) |
  371. (((q4 >> 4) & 0x3) << 6)
  372. qh[qhBase+l] = valH
  373. }
  374. }
  375. // Write scales (16 bytes, int8 as uint8)
  376. scales := outBlock[192:208]
  377. for i := 0; i < 16; i++ {
  378. scales[i] = uint8(ls[i])
  379. }
  380. // Write d as float16 (little endian)
  381. dF16 := float32ToFloat16(dVal)
  382. outBlock[208] = byte(dF16)
  383. outBlock[209] = byte(dF16 >> 8)
  384. }
  385. return out
  386. }
  387. // QuantizeQ3K quantizes float32 data to Q3_K format
  388. // Layout (110 bytes per 256 elements):
  389. // - HMask (32 bytes): high bit of 3-bit quants
  390. // - QS (64 bytes): low 2 bits of 3-bit quants
  391. // - Scales (12 bytes): packed 6-bit signed scales
  392. // - D (2 bytes): float16 super-scale
  393. func QuantizeQ3K(data []float32) []byte {
  394. n := len(data)
  395. padding := (QK_K - (n % QK_K)) % QK_K
  396. if padding > 0 {
  397. padded := make([]float32, n+padding)
  398. copy(padded, data)
  399. data = padded
  400. }
  401. nBlocks := len(data) / QK_K
  402. out := make([]byte, nBlocks*110)
  403. var scales [16]float32
  404. var ls [16]uint8
  405. var lFinal [256]uint8
  406. for b := 0; b < nBlocks; b++ {
  407. block := data[b*QK_K : (b+1)*QK_K]
  408. outBlock := out[b*110 : (b+1)*110]
  409. hmask := outBlock[0:32]
  410. qs := outBlock[32:96]
  411. scalesPacked := outBlock[96:108]
  412. // First pass: compute per-sub-block scales
  413. var maxScale float32
  414. var maxAbs float32
  415. for j := 0; j < 16; j++ {
  416. sub := block[j*16 : (j+1)*16]
  417. sc := makeQ3Scale(sub)
  418. scales[j] = sc
  419. abs := float32(math.Abs(float64(sc)))
  420. if abs > maxAbs {
  421. maxAbs = abs
  422. maxScale = sc
  423. }
  424. }
  425. if maxAbs == 0 {
  426. // All zero block -> already zeroed
  427. continue
  428. }
  429. iscale := -32.0 / maxScale
  430. dVal := float32(1.0 / iscale)
  431. // Quantize scales to 6-bit signed, packed
  432. for j := 0; j < 16; j++ {
  433. l := clampInt(int(math.Round(float64(iscale*scales[j]))), -32, 31) + 32
  434. ls[j] = uint8(l)
  435. }
  436. packQ3Scales(ls, scalesPacked)
  437. // Re-quantize weights using packed scales
  438. for j := 0; j < 16; j++ {
  439. sc := unpackQ3Scale(scalesPacked, j)
  440. dLocal := dVal * float32(sc)
  441. if dLocal == 0 {
  442. for i := 0; i < 16; i++ {
  443. lFinal[j*16+i] = 0
  444. }
  445. continue
  446. }
  447. sub := block[j*16 : (j+1)*16]
  448. for i := 0; i < 16; i++ {
  449. q := clampInt(int(math.Round(float64(sub[i]/dLocal))), -4, 3)
  450. lFinal[j*16+i] = uint8(q + 4)
  451. }
  452. }
  453. // Build hmask and strip high bit
  454. m := 0
  455. hm := uint8(1)
  456. for j := 0; j < QK_K; j++ {
  457. if lFinal[j] > 3 {
  458. hmask[m] |= hm
  459. lFinal[j] -= 4
  460. }
  461. m++
  462. if m == QK_K/8 {
  463. m = 0
  464. hm <<= 1
  465. }
  466. }
  467. // Pack QS: four 2-bit lanes per byte
  468. for nIdx := 0; nIdx < 256; nIdx += 128 {
  469. for l := 0; l < 32; l++ {
  470. qs[nIdx/4+l] = lFinal[nIdx+l] |
  471. (lFinal[nIdx+l+32] << 2) |
  472. (lFinal[nIdx+l+64] << 4) |
  473. (lFinal[nIdx+l+96] << 6)
  474. }
  475. }
  476. // Write d
  477. dF16 := float32ToFloat16(dVal)
  478. outBlock[108] = byte(dF16)
  479. outBlock[109] = byte(dF16 >> 8)
  480. }
  481. return out
  482. }
  483. // QuantizeQ2K quantizes float32 data to Q2_K format
  484. // Layout (84 bytes per 256 elements):
  485. // - Scales (16 bytes): 4-bit scale + 4-bit min per sub-block
  486. // - QS (64 bytes): packed 2-bit quants
  487. // - D (2 bytes): float16 super-scale
  488. // - DMin (2 bytes): float16 super-min-scale
  489. func QuantizeQ2K(data []float32) []byte {
  490. n := len(data)
  491. padding := (QK_K - (n % QK_K)) % QK_K
  492. if padding > 0 {
  493. padded := make([]float32, n+padding)
  494. copy(padded, data)
  495. data = padded
  496. }
  497. nBlocks := len(data) / QK_K
  498. out := make([]byte, nBlocks*84)
  499. var scales [16]float32
  500. var mins [16]float32
  501. var scaleNib [16]uint8
  502. var minNib [16]uint8
  503. var lFinal [256]uint8
  504. for b := 0; b < nBlocks; b++ {
  505. block := data[b*QK_K : (b+1)*QK_K]
  506. outBlock := out[b*84 : (b+1)*84]
  507. scalesPacked := outBlock[0:16]
  508. qs := outBlock[16:80]
  509. var maxScale float32
  510. var maxMin float32
  511. // Per-sub-block quant search
  512. for j := 0; j < 16; j++ {
  513. sub := block[j*16 : (j+1)*16]
  514. scale, lTmp, tgtMin := makeQKX2Quants(sub)
  515. scales[j] = scale
  516. mins[j] = tgtMin
  517. if scale > maxScale {
  518. maxScale = scale
  519. }
  520. if tgtMin > maxMin {
  521. maxMin = tgtMin
  522. }
  523. // lTmp unused here; we re-quantize after super-scale
  524. _ = lTmp
  525. }
  526. var dVal float32
  527. if maxScale > 0 {
  528. inv := 15.0 / maxScale
  529. for j := 0; j < 16; j++ {
  530. scaleNib[j] = uint8(clampInt(int(math.Round(float64(inv*scales[j]))), 0, 15))
  531. }
  532. dVal = maxScale / 15.0
  533. } else {
  534. for j := 0; j < 16; j++ {
  535. scaleNib[j] = 0
  536. }
  537. dVal = 0
  538. }
  539. var dminVal float32
  540. if maxMin > 0 {
  541. invMin := 15.0 / maxMin
  542. for j := 0; j < 16; j++ {
  543. minNib[j] = uint8(clampInt(int(math.Round(float64(invMin*mins[j]))), 0, 15))
  544. }
  545. dminVal = maxMin / 15.0
  546. } else {
  547. for j := 0; j < 16; j++ {
  548. minNib[j] = 0
  549. }
  550. dminVal = 0
  551. }
  552. // Pack scales/mins nibbles
  553. for j := 0; j < 16; j++ {
  554. scalesPacked[j] = (scaleNib[j] & 0xF) | ((minNib[j] & 0xF) << 4)
  555. }
  556. // Re-quantize weights with quantized scales/mins
  557. for j := 0; j < 16; j++ {
  558. dl := dVal * float32(scaleNib[j])
  559. if dl == 0 {
  560. for i := 0; i < 16; i++ {
  561. lFinal[j*16+i] = 0
  562. }
  563. continue
  564. }
  565. dm := dminVal * float32(minNib[j])
  566. sub := block[j*16 : (j+1)*16]
  567. for i := 0; i < 16; i++ {
  568. q := clampInt(nearestIntFloat32((sub[i]+dm)/dl), 0, 3)
  569. lFinal[j*16+i] = uint8(q)
  570. }
  571. }
  572. // Pack QS
  573. for nIdx := 0; nIdx < 256; nIdx += 128 {
  574. for l := 0; l < 32; l++ {
  575. qs[nIdx/4+l] = lFinal[nIdx+l] |
  576. (lFinal[nIdx+l+32] << 2) |
  577. (lFinal[nIdx+l+64] << 4) |
  578. (lFinal[nIdx+l+96] << 6)
  579. }
  580. }
  581. // Write d and dmin
  582. dF16 := float32ToFloat16(dVal)
  583. dminF16 := float32ToFloat16(dminVal)
  584. outBlock[80] = byte(dF16)
  585. outBlock[81] = byte(dF16 >> 8)
  586. outBlock[82] = byte(dminF16)
  587. outBlock[83] = byte(dminF16 >> 8)
  588. }
  589. return out
  590. }
  591. // QuantizeQ4K quantizes float32 data to Q4_K format
  592. // Layout (144 bytes per 256 elements):
  593. // - D (2 bytes): float16 super-scale
  594. // - DMin (2 bytes): float16 super-min-scale
  595. // - Scales (12 bytes): packed 6-bit scales and mins
  596. // - QS (128 bytes): 256 4-bit quants
  597. func QuantizeQ4K(data []float32) []byte {
  598. n := len(data)
  599. padding := (QK_K - (n % QK_K)) % QK_K
  600. if padding > 0 {
  601. padded := make([]float32, n+padding)
  602. copy(padded, data)
  603. data = padded
  604. }
  605. nBlocks := len(data) / QK_K
  606. out := make([]byte, nBlocks*144)
  607. for b := 0; b < nBlocks; b++ {
  608. superblock := data[b*QK_K : (b+1)*QK_K]
  609. outBlock := out[b*144 : (b+1)*144]
  610. // Calculate min/max/scale per 32-element sub-block (8 sub-blocks)
  611. var sbMin, sbMax, sbScale [8]float32
  612. var targetMins [8]float32
  613. for j := 0; j < 8; j++ {
  614. sub := superblock[j*32 : (j+1)*32]
  615. minVal := sub[0]
  616. maxVal := sub[0]
  617. for _, v := range sub {
  618. if v < minVal {
  619. minVal = v
  620. }
  621. if v > maxVal {
  622. maxVal = v
  623. }
  624. }
  625. sbMin[j] = minVal
  626. sbMax[j] = maxVal
  627. // Constrain min to be at most 0
  628. minConstrained := minVal
  629. if minConstrained > 0 {
  630. minConstrained = 0
  631. }
  632. sbScale[j] = (maxVal - minConstrained) / 15.0
  633. targetMins[j] = -minConstrained // >= 0
  634. }
  635. // Super-block scales
  636. var maxScaleVal, maxMinVal float32
  637. for j := 0; j < 8; j++ {
  638. if sbScale[j] > maxScaleVal {
  639. maxScaleVal = sbScale[j]
  640. }
  641. if targetMins[j] > maxMinVal {
  642. maxMinVal = targetMins[j]
  643. }
  644. }
  645. dVal := maxScaleVal / 63.0
  646. dminVal := maxMinVal / 63.0
  647. // Avoid division by zero
  648. if dVal == 0 {
  649. dVal = 1.0
  650. }
  651. if dminVal == 0 {
  652. dminVal = 1.0
  653. }
  654. // Quantize scales and mins to 6 bits
  655. var ls, lm [8]uint8
  656. for j := 0; j < 8; j++ {
  657. ls[j] = uint8(clampInt(int(math.Round(float64(sbScale[j]/dVal))), 0, 63))
  658. lm[j] = uint8(clampInt(int(math.Round(float64(targetMins[j]/dminVal))), 0, 63))
  659. }
  660. // Restore zeros
  661. if maxScaleVal == 0 {
  662. dVal = 0
  663. }
  664. if maxMinVal == 0 {
  665. dminVal = 0
  666. }
  667. // Reconstruct local scales/mins
  668. var recS, recM [8]float32
  669. for j := 0; j < 8; j++ {
  670. recS[j] = float32(ls[j]) * dVal
  671. recM[j] = float32(lm[j]) * dminVal
  672. }
  673. // Quantize weights: w = q * s - m => q = (w + m) / s
  674. var qVals [256]uint8
  675. for j := 0; j < 8; j++ {
  676. s := recS[j]
  677. m := recM[j]
  678. if s == 0 {
  679. s = 1.0
  680. }
  681. for i := 0; i < 32; i++ {
  682. q := int(math.Round(float64((superblock[j*32+i] + m) / s)))
  683. qVals[j*32+i] = uint8(clampInt(q, 0, 15))
  684. }
  685. }
  686. // Write D and DMin as float16
  687. dF16 := float32ToFloat16(dVal)
  688. dminF16 := float32ToFloat16(dminVal)
  689. outBlock[0] = byte(dF16)
  690. outBlock[1] = byte(dF16 >> 8)
  691. outBlock[2] = byte(dminF16)
  692. outBlock[3] = byte(dminF16 >> 8)
  693. // Pack scales (12 bytes)
  694. scales := outBlock[4:16]
  695. // scales[0..3] = ls[0..3] | (ls[4..7] high 2 bits << 6)
  696. // scales[4..7] = lm[0..3] | (lm[4..7] high 2 bits << 6)
  697. // scales[8..11] = (ls[4..7] low 4 bits) | (lm[4..7] low 4 bits << 4)
  698. for j := 0; j < 4; j++ {
  699. scales[j] = ls[j] | ((ls[j+4] >> 4) << 6)
  700. scales[j+4] = lm[j] | ((lm[j+4] >> 4) << 6)
  701. }
  702. for j := 0; j < 4; j++ {
  703. scales[8+j] = (ls[j+4] & 0xF) | ((lm[j+4] & 0xF) << 4)
  704. }
  705. // Pack QS (128 bytes): pairs of nibbles
  706. qs := outBlock[16:144]
  707. for chunk := 0; chunk < 4; chunk++ {
  708. base := chunk * 64
  709. for l := 0; l < 32; l++ {
  710. low := qVals[base+l]
  711. high := qVals[base+l+32]
  712. qs[chunk*32+l] = low | (high << 4)
  713. }
  714. }
  715. }
  716. return out
  717. }
  718. // Helper functions
  719. func clampInt(v, lo, hi int) int {
  720. if v < lo {
  721. return lo
  722. }
  723. if v > hi {
  724. return hi
  725. }
  726. return v
  727. }
  728. // nearestIntFloat32 matches llama.cpp's nearest_int for float32 inputs.
  729. func nearestIntFloat32(v float32) int {
  730. const magic = 12582912.0 // 2^23 + 2^22
  731. f := v + magic
  732. bits := math.Float32bits(f)
  733. return int(bits&0x007FFFFF) - 0x00400000
  734. }
  735. // float32ToFloat16 converts float32 to float16 (IEEE 754 half precision)
  736. func float32ToFloat16(f float32) uint16 {
  737. bits := math.Float32bits(f)
  738. sign := uint16((bits >> 16) & 0x8000)
  739. exp := int((bits >> 23) & 0xFF)
  740. mant := bits & 0x007FFFFF
  741. if exp == 0xFF {
  742. // Inf or NaN
  743. if mant == 0 {
  744. return sign | 0x7C00 // Inf
  745. }
  746. return sign | 0x7E00 // NaN
  747. }
  748. if exp == 0 {
  749. // Zero or denormal
  750. return sign
  751. }
  752. // Rebias exponent from 127 to 15
  753. newExp := exp - 127 + 15
  754. if newExp >= 31 {
  755. // Overflow to infinity
  756. return sign | 0x7C00
  757. }
  758. if newExp <= 0 {
  759. // Underflow to zero or denormal
  760. if newExp < -10 {
  761. return sign
  762. }
  763. // Denormal
  764. mant |= 0x00800000
  765. shift := uint(14 - newExp)
  766. // Round to nearest-even while shifting
  767. value := mant >> shift
  768. roundMask := (uint32(1) << shift) - 1
  769. roundMid := uint32(1) << (shift - 1)
  770. roundBits := mant & roundMask
  771. if roundBits > roundMid || (roundBits == roundMid && (value&1) != 0) {
  772. value++
  773. }
  774. // Renormalize if rounding overflowed the mantissa
  775. if value == 0x00000400 {
  776. return sign | uint16(1<<10)
  777. }
  778. return sign | uint16(value)
  779. }
  780. // Normalized number: round mantissa to nearest-even before truncation
  781. mant += 0x00001000 // add 0.5 ulp at bit 12 (23-10-1)
  782. // Handle mantissa overflow into exponent
  783. if mant&0x00800000 != 0 {
  784. mant = 0
  785. newExp++
  786. if newExp >= 31 {
  787. return sign | 0x7C00
  788. }
  789. }
  790. return sign | (uint16(newExp) << 10) | uint16(mant>>13)
  791. }
// makeQ3Scale returns a signed scale for one Q3_K sub-block, chosen so that
// x[i] ≈ scale * L[i] with integer L[i] in [-4, 3]. x must hold at most 16
// values (L is a [16] array); the per-value L is internal and not returned.
//
// It starts from the grid that maps the largest-magnitude value to -nmax,
// then runs up to 5 passes that try re-rounding one L[i] at a time. A change
// is kept only if it increases sumlx²/suml2, the reduction in x²-weighted
// squared error for the least-squares scale. The returned scale is that
// least-squares fit, sumlx/suml2, computed in float64.
//
// Returns 0 when every |x[i]| is below 1e-15 or when suml2 ends up 0.
// NOTE(review): rounding here uses math.Round (ties away from zero), unlike
// nearestIntFloat32 used elsewhere in this file; confirm this is intended.
func makeQ3Scale(x []float32) float32 {
	const nmax = 4.0
	const eps = 1e-15
	// Find max absolute and the value achieving it (with sign)
	var amax float64
	var maxVal float64
	for _, v := range x {
		av := math.Abs(float64(v))
		if av > amax {
			amax = av
			maxVal = float64(v)
		}
	}
	if amax < eps {
		return 0
	}
	// Initial grid: the largest-magnitude value maps to -nmax.
	iscale := -nmax / maxVal
	var L [16]float64
	var sumlx float64
	var suml2 float64
	for i, v := range x {
		l := float64(clampInt(int(math.Round(iscale*float64(v))), -4, 3))
		L[i] = l
		// Weights are x²; sumlx/suml2 is the weighted least-squares scale.
		w := float64(v) * float64(v)
		sumlx += w * float64(v) * l
		suml2 += w * l * l
	}
	for iter := 0; iter < 5; iter++ {
		changed := 0
		for i, v := range x {
			w := float64(v) * float64(v)
			// Sums with value i's contribution removed.
			slx := sumlx - w*float64(v)*L[i]
			if slx > 0 {
				sl2 := suml2 - w*L[i]*L[i]
				// Best level for x[i] given the scale implied by the others.
				newL := float64(clampInt(int(math.Round(float64(v)*sl2/slx)), -4, 3))
				if newL != L[i] {
					slxNew := slx + w*float64(v)*newL
					sl2New := sl2 + w*newL*newL
					// Accept only if slxNew²/sl2New > sumlx²/suml2.
					if sl2New > 0 && slxNew*slxNew*suml2 > sumlx*sumlx*sl2New {
						L[i] = newL
						sumlx = slxNew
						suml2 = sl2New
						changed++
					}
				}
			}
		}
		if changed == 0 {
			break
		}
	}
	if suml2 == 0 {
		return 0
	}
	return float32(sumlx / suml2)
}
  850. // packQ3Scales packs 16 6-bit signed scales into 12 bytes (llama.cpp layout).
  851. func packQ3Scales(ls [16]uint8, dst []byte) {
  852. for i := range dst {
  853. dst[i] = 0
  854. }
  855. for j := 0; j < 16; j++ {
  856. l := ls[j]
  857. low := l & 0xF
  858. high := (l >> 4) & 0x3
  859. if j < 8 {
  860. dst[j] = low
  861. } else {
  862. dst[j-8] |= low << 4
  863. }
  864. dst[8+(j%4)] |= high << (2 * (j / 4))
  865. }
  866. }
  867. // unpackQ3Scale reverses packQ3Scales for a single index, returning signed scale [-32,31].
  868. func unpackQ3Scale(packed []byte, idx int) int8 {
  869. var sc uint8
  870. if idx < 8 {
  871. sc = packed[idx] & 0xF
  872. } else {
  873. sc = packed[idx-8] >> 4
  874. }
  875. sc |= ((packed[8+(idx%4)] >> (2 * (idx / 4))) & 0x3) << 4
  876. return int8(sc) - 32
  877. }
// makeQKX2Quants fits an affine 2-bit quantization to one 16-value Q2_K
// sub-block (port of llama.cpp/makarna python): x[i] ≈ scale*L[i] + min,
// with min <= 0 and L[i] in [0, 3]. Only x[0..15] are read (shorter slices
// panic).
//
// Each value is weighted by |x[i]| and errors are absolute, not squared.
// The search first quantizes on the plain [min, max] grid, then tries
// nstep+1 = 16 perturbed inverse scales (nmax + rmin + rdelta*k over the
// range). For each candidate grid it refits scale and min by weighted least
// squares — or scale alone with min = 0 when the fitted min is positive —
// and keeps the candidate with the lowest weighted error.
//
// Returns the best scale, its quants (ignored by QuantizeQ2K, which
// re-quantizes), and targetMin = -min (>= 0). If all values are equal once
// min is clamped to <= 0, it returns scale 0, all-zero quants and -min.
//
// NOTE(review): minVal (the grid origin) is never updated during the search,
// even after a better min is found; confirm this matches the reference
// make_qkx2_quants if bit-exact parity with llama.cpp matters.
func makeQKX2Quants(x []float32) (float32, [16]uint8, float32) {
	const nmax = 3.0
	const rmin = -0.5
	const rdelta = 0.1
	const nstep = 15
	var lBest [16]uint8
	// Range and |x|-weighted sums over the sub-block.
	minVal := x[0]
	maxVal := x[0]
	sumW := float32(math.Abs(float64(x[0])))
	sumX := sumW * x[0]
	for i := 1; i < 16; i++ {
		v := x[i]
		if v < minVal {
			minVal = v
		}
		if v > maxVal {
			maxVal = v
		}
		w := float32(math.Abs(float64(v)))
		sumW += w
		sumX += w * v
	}
	// The offset is never positive, so the returned -min stays >= 0.
	if minVal > 0 {
		minVal = 0
	}
	// Degenerate range: every value decodes to min with scale 0.
	if maxVal == minVal {
		return 0, lBest, -minVal
	}
	// Baseline candidate: plain min/max grid with nmax steps.
	iscale := float32(nmax) / (maxVal - minVal)
	scale := 1 / iscale
	var L [16]uint8
	bestErr := float32(0)
	for i := 0; i < 16; i++ {
		l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 3)
		L[i] = uint8(l)
		diff := scale*float32(l) + minVal - x[i]
		if diff < 0 {
			diff = -diff
		}
		w := float32(math.Abs(float64(x[i])))
		bestErr += w * diff
	}
	bestScale := scale
	bestMin := minVal
	copy(lBest[:], L[:])
	for isIdx := 0; isIdx <= nstep; isIdx++ {
		// Perturbed inverse scale for this candidate grid.
		iscale = (float32(rmin) + float32(rdelta)*float32(isIdx) + float32(nmax)) / (maxVal - minVal)
		var Laux [16]uint8
		var sumL, sumL2, sumXL float32
		for i := 0; i < 16; i++ {
			l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 3)
			Laux[i] = uint8(l)
			lf := float32(l)
			w := float32(math.Abs(float64(x[i])))
			sumL += w * lf
			sumL2 += w * lf * lf
			sumXL += w * lf * x[i]
		}
		// Weighted least squares for x ≈ thisScale*L + thisMin; D is the
		// determinant of the normal equations.
		D := sumW*sumL2 - sumL*sumL
		if D > 0 {
			thisScale := (sumW*sumXL - sumX*sumL) / D
			thisMin := (sumL2*sumX - sumL*sumXL) / D
			// Positive offsets are not allowed: refit the scale with min = 0.
			if thisMin > 0 {
				thisMin = 0
				thisScale = sumXL / sumL2
			}
			curErr := float32(0)
			for i := 0; i < 16; i++ {
				diff := thisScale*float32(Laux[i]) + thisMin - x[i]
				if diff < 0 {
					diff = -diff
				}
				w := float32(math.Abs(float64(x[i])))
				curErr += w * diff
			}
			if curErr < bestErr {
				copy(lBest[:], Laux[:])
				bestErr = curErr
				bestScale = thisScale
				bestMin = thisMin
			}
		}
	}
	return bestScale, lBest, -bestMin
}