| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095 |
- // Package quant provides fast quantization functions for model weights
- package quant
- import (
- "math"
- )
// QK_K is the number of float32 elements grouped into one quantization
// block. The Quantize* functions in this file zero-pad their input to a
// multiple of QK_K and emit one fixed-size record per QK_K elements.
const QK_K = 256
- // QuantizeQ8K quantizes float32 data to Q8_K format
- // Block layout (292 bytes per 256 elements):
- // - D (4 bytes): float32 scale
- // - QS (256 bytes): 256 int8 quants
- // - BSums (32 bytes): 16 int16 block sums
- func QuantizeQ8K(data []float32) []byte {
- // Pad to multiple of QK_K
- n := len(data)
- padding := (QK_K - (n % QK_K)) % QK_K
- if padding > 0 {
- padded := make([]float32, n+padding)
- copy(padded, data)
- data = padded
- }
- nBlocks := len(data) / QK_K
- // 292 bytes per block: 4 (d) + 256 (qs) + 32 (bsums)
- out := make([]byte, nBlocks*292)
- for b := 0; b < nBlocks; b++ {
- block := data[b*QK_K : (b+1)*QK_K]
- outBlock := out[b*292 : (b+1)*292]
- // Find max absolute value
- var amax float32
- for _, v := range block {
- if abs := float32(math.Abs(float64(v))); abs > amax {
- amax = abs
- }
- }
- // Calculate scale
- d := amax / 127.0
- var iscale float32
- if amax > 0 {
- iscale = 127.0 / amax
- }
- // Write d as float32 (little endian)
- dBits := math.Float32bits(d)
- outBlock[0] = byte(dBits)
- outBlock[1] = byte(dBits >> 8)
- outBlock[2] = byte(dBits >> 16)
- outBlock[3] = byte(dBits >> 24)
- // Quantize and write QS, calculate bsums
- var bsums [16]int16
- for i := 0; i < QK_K; i++ {
- q := int8(clampInt(int(math.Round(float64(block[i]*iscale))), -127, 127))
- outBlock[4+i] = byte(q)
- bsums[i/16] += int16(q)
- }
- // Write bsums (16 int16, little endian)
- for i := 0; i < 16; i++ {
- outBlock[260+i*2] = byte(bsums[i])
- outBlock[260+i*2+1] = byte(bsums[i] >> 8)
- }
- }
- return out
- }
// getScaleMinK4 extracts the j-th 6-bit scale (d) and 6-bit min (m) from the
// 12-byte packed table q.
//
// For j < 4 both values are the low six bits of q[j] and q[j+4]. For larger
// j the low nibbles of both values share byte q[j+4], while the two high bits
// come from bits 6-7 of q[j-4] (scale) and q[j] (min). The indexing is only
// in range for j in [0, 7].
func getScaleMinK4(j int, q *[12]uint8) (d uint8, m uint8) {
	if j >= 4 {
		shared := q[j+4]
		d = (shared & 0xF) | ((q[j-4] >> 6) << 4)
		m = (shared >> 4) | ((q[j] >> 6) << 4)
		return d, m
	}
	return q[j] & 63, q[j+4] & 63
}
- func makeQKX2Quants32(x []float32, weights []float32) (float32, [32]uint8, float32) {
- const n = 32
- const nmax = 31.0
- const rmin = -0.5
- const rdelta = 0.1
- const nstep = 15
- const useMAD = false
- var lBest [32]uint8
- minVal := x[0]
- maxVal := x[0]
- sumW := weights[0]
- sumX := sumW * x[0]
- for i := 1; i < n; i++ {
- v := x[i]
- if v < minVal {
- minVal = v
- }
- if v > maxVal {
- maxVal = v
- }
- w := weights[i]
- sumW += w
- sumX += w * v
- }
- if minVal > 0 {
- minVal = 0
- }
- if maxVal == minVal {
- return 0, lBest, -minVal
- }
- iscale := float32(nmax) / (maxVal - minVal)
- scale := 1 / iscale
- bestErr := float32(0)
- var L [32]uint8
- for i := 0; i < n; i++ {
- l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 31)
- L[i] = uint8(l)
- diff := scale*float32(l) + minVal - x[i]
- if useMAD {
- if diff < 0 {
- diff = -diff
- }
- } else {
- diff = diff * diff
- }
- bestErr += weights[i] * diff
- }
- bestScale := scale
- bestMin := minVal
- copy(lBest[:], L[:])
- var Laux [32]uint8
- for isIdx := 0; isIdx <= nstep; isIdx++ {
- iscale = (float32(rmin) + float32(rdelta)*float32(isIdx) + float32(nmax)) / (maxVal - minVal)
- var sumL, sumL2, sumXL float32
- for i := 0; i < n; i++ {
- l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 31)
- Laux[i] = uint8(l)
- lf := float32(l)
- w := weights[i]
- sumL += w * lf
- sumL2 += w * lf * lf
- sumXL += w * lf * x[i]
- }
- D := sumW*sumL2 - sumL*sumL
- if D > 0 {
- thisScale := (sumW*sumXL - sumX*sumL) / D
- thisMin := (sumL2*sumX - sumL*sumXL) / D
- if thisMin > 0 {
- thisMin = 0
- thisScale = sumXL / sumL2
- }
- curErr := float32(0)
- for i := 0; i < n; i++ {
- diff := thisScale*float32(Laux[i]) + thisMin - x[i]
- if useMAD {
- if diff < 0 {
- diff = -diff
- }
- } else {
- diff = diff * diff
- }
- curErr += weights[i] * diff
- }
- if curErr < bestErr {
- copy(lBest[:], Laux[:])
- bestErr = curErr
- bestScale = thisScale
- bestMin = thisMin
- }
- }
- }
- return bestScale, lBest, -bestMin
- }
- func QuantizeQ5K(data []float32) []byte {
- n := len(data)
- padding := (QK_K - (n % QK_K)) % QK_K
- if padding > 0 {
- padded := make([]float32, n+padding)
- copy(padded, data)
- data = padded
- }
- nBlocks := len(data) / QK_K
- out := make([]byte, nBlocks*176)
- for b := 0; b < nBlocks; b++ {
- block := data[b*QK_K : (b+1)*QK_K]
- outBlock := out[b*176 : (b+1)*176]
- var L [QK_K]uint8
- var mins [8]float32
- var scales [8]float32
- maxScale := float32(0)
- maxMin := float32(0)
- var weights [32]float32
- for j := 0; j < 8; j++ {
- seg := block[j*32 : (j+1)*32]
- var sumX2 float32
- for l := 0; l < 32; l++ {
- v := seg[l]
- sumX2 += v * v
- }
- avX := float32(math.Sqrt(float64(sumX2 / 32)))
- for l := 0; l < 32; l++ {
- v := seg[l]
- absV := float32(math.Abs(float64(v)))
- weights[l] = avX + absV
- }
- sc, lq, mn := makeQKX2Quants32(seg, weights[:])
- scales[j] = sc
- mins[j] = mn
- copy(L[j*32:(j+1)*32], lq[:])
- if sc > maxScale {
- maxScale = sc
- }
- if mn > maxMin {
- maxMin = mn
- }
- }
- invScale := float32(0)
- invMin := float32(0)
- if maxScale > 0 {
- invScale = 63.0 / maxScale
- }
- if maxMin > 0 {
- invMin = 63.0 / maxMin
- }
- var scalesPacked [12]uint8
- for j := 0; j < 8; j++ {
- ls := uint8(clampInt(nearestIntFloat32(invScale*scales[j]), 0, 63))
- lm := uint8(clampInt(nearestIntFloat32(invMin*mins[j]), 0, 63))
- if j < 4 {
- scalesPacked[j] = ls
- scalesPacked[j+4] = lm
- } else {
- scalesPacked[j+4] = (ls & 0xF) | ((lm & 0xF) << 4)
- scalesPacked[j-4] |= ((ls >> 4) << 6)
- scalesPacked[j] |= ((lm >> 4) << 6)
- }
- }
- copy(outBlock[4:16], scalesPacked[:])
- dVal := maxScale / 63.0
- dMinVal := maxMin / 63.0
- dF16 := float32ToFloat16(dVal)
- dMinF16 := float32ToFloat16(dMinVal)
- outBlock[0] = byte(dF16)
- outBlock[1] = byte(dF16 >> 8)
- outBlock[2] = byte(dMinF16)
- outBlock[3] = byte(dMinF16 >> 8)
- if maxScale > 0 {
- for j := 0; j < 8; j++ {
- sc, m := getScaleMinK4(j, &scalesPacked)
- dLocal := dVal * float32(sc)
- if dLocal == 0 {
- continue
- }
- dm := dMinVal * float32(m)
- for ii := 0; ii < 32; ii++ {
- l := nearestIntFloat32((block[j*32+ii] + dm) / dLocal)
- L[j*32+ii] = uint8(clampInt(l, 0, 31))
- }
- }
- }
- qh := outBlock[16:48]
- qs := outBlock[48:176]
- for i := range qh {
- qh[i] = 0
- }
- m1 := uint8(1)
- m2 := uint8(2)
- qsOff := 0
- for n0 := 0; n0 < QK_K; n0 += 64 {
- for j := 0; j < 32; j++ {
- l1 := L[n0+j]
- if l1 > 15 {
- l1 -= 16
- qh[j] |= m1
- }
- l2 := L[n0+j+32]
- if l2 > 15 {
- l2 -= 16
- qh[j] |= m2
- }
- qs[qsOff+j] = l1 | (l2 << 4)
- }
- m1 <<= 2
- m2 <<= 2
- qsOff += 32
- }
- }
- return out
- }
- // QuantizeQ6K quantizes float32 data to Q6_K format
- // Layout (210 bytes per 256 elements):
- // - QL (128 bytes): lower 4 bits of 6-bit quants
- // - QH (64 bytes): upper 2 bits of 6-bit quants
- // - Scales (16 bytes): 8-bit signed scales
- // - D (2 bytes): float16 super-scale
- func QuantizeQ6K(data []float32) []byte {
- n := len(data)
- padding := (QK_K - (n % QK_K)) % QK_K
- if padding > 0 {
- padded := make([]float32, n+padding)
- copy(padded, data)
- data = padded
- }
- nBlocks := len(data) / QK_K
- out := make([]byte, nBlocks*210)
- for b := 0; b < nBlocks; b++ {
- block := data[b*QK_K : (b+1)*QK_K]
- outBlock := out[b*210 : (b+1)*210]
- // Calculate scales per 16-element sub-block (16 sub-blocks)
- var sbScale [16]float32
- var maxScale float32
- for j := 0; j < 16; j++ {
- sub := block[j*16 : (j+1)*16]
- var sbMax float32
- for _, v := range sub {
- if abs := float32(math.Abs(float64(v))); abs > sbMax {
- sbMax = abs
- }
- }
- if sbMax == 0 {
- sbMax = 1.0
- }
- sbScale[j] = sbMax / 31.5
- if sbScale[j] == 0 {
- sbScale[j] = 1.0
- }
- if sbScale[j] > maxScale {
- maxScale = sbScale[j]
- }
- }
- // Super-block scale
- dVal := maxScale / 127.0
- if dVal == 0 {
- dVal = 1.0
- }
- // Quantize sub-scales to 8-bit signed
- var ls [16]int8
- for j := 0; j < 16; j++ {
- ls[j] = int8(clampInt(int(math.Round(float64(sbScale[j]/dVal))), -128, 127))
- }
- // Restore dVal zeros
- if maxScale == 0 {
- dVal = 0
- }
- // Reconstruct scales and quantize weights
- var qVals [256]uint8
- for j := 0; j < 16; j++ {
- recS := float32(ls[j]) * dVal
- if recS == 0 {
- recS = 1.0
- }
- for i := 0; i < 16; i++ {
- q := int(math.Round(float64(block[j*16+i] / recS)))
- q = clampInt(q, -32, 31)
- qVals[j*16+i] = uint8(q + 32) // [0, 63]
- }
- }
- // Pack QL and QH
- ql := outBlock[0:128]
- qh := outBlock[128:192]
- // Process 2 halves of 128 weights each
- for nIdx := 0; nIdx < 256; nIdx += 128 {
- qlBase := nIdx / 2 // 0 or 64
- qhBase := nIdx / 4 // 0 or 32
- for l := 0; l < 32; l++ {
- idx1 := nIdx + l
- idx2 := nIdx + l + 32
- idx3 := nIdx + l + 64
- idx4 := nIdx + l + 96
- q1 := qVals[idx1]
- q2 := qVals[idx2]
- q3 := qVals[idx3]
- q4 := qVals[idx4]
- // Pack QL
- ql[qlBase+l] = (q1 & 0xF) | ((q3 & 0xF) << 4)
- ql[qlBase+l+32] = (q2 & 0xF) | ((q4 & 0xF) << 4)
- // Pack QH
- valH := ((q1 >> 4) & 0x3) |
- (((q2 >> 4) & 0x3) << 2) |
- (((q3 >> 4) & 0x3) << 4) |
- (((q4 >> 4) & 0x3) << 6)
- qh[qhBase+l] = valH
- }
- }
- // Write scales (16 bytes, int8 as uint8)
- scales := outBlock[192:208]
- for i := 0; i < 16; i++ {
- scales[i] = uint8(ls[i])
- }
- // Write d as float16 (little endian)
- dF16 := float32ToFloat16(dVal)
- outBlock[208] = byte(dF16)
- outBlock[209] = byte(dF16 >> 8)
- }
- return out
- }
- // QuantizeQ3K quantizes float32 data to Q3_K format
- // Layout (110 bytes per 256 elements):
- // - HMask (32 bytes): high bit of 3-bit quants
- // - QS (64 bytes): low 2 bits of 3-bit quants
- // - Scales (12 bytes): packed 6-bit signed scales
- // - D (2 bytes): float16 super-scale
- func QuantizeQ3K(data []float32) []byte {
- n := len(data)
- padding := (QK_K - (n % QK_K)) % QK_K
- if padding > 0 {
- padded := make([]float32, n+padding)
- copy(padded, data)
- data = padded
- }
- nBlocks := len(data) / QK_K
- out := make([]byte, nBlocks*110)
- var scales [16]float32
- var ls [16]uint8
- var lFinal [256]uint8
- for b := 0; b < nBlocks; b++ {
- block := data[b*QK_K : (b+1)*QK_K]
- outBlock := out[b*110 : (b+1)*110]
- hmask := outBlock[0:32]
- qs := outBlock[32:96]
- scalesPacked := outBlock[96:108]
- // First pass: compute per-sub-block scales
- var maxScale float32
- var maxAbs float32
- for j := 0; j < 16; j++ {
- sub := block[j*16 : (j+1)*16]
- sc := makeQ3Scale(sub)
- scales[j] = sc
- abs := float32(math.Abs(float64(sc)))
- if abs > maxAbs {
- maxAbs = abs
- maxScale = sc
- }
- }
- if maxAbs == 0 {
- // All zero block -> already zeroed
- continue
- }
- iscale := -32.0 / maxScale
- dVal := float32(1.0 / iscale)
- // Quantize scales to 6-bit signed, packed
- for j := 0; j < 16; j++ {
- l := clampInt(int(math.Round(float64(iscale*scales[j]))), -32, 31) + 32
- ls[j] = uint8(l)
- }
- packQ3Scales(ls, scalesPacked)
- // Re-quantize weights using packed scales
- for j := 0; j < 16; j++ {
- sc := unpackQ3Scale(scalesPacked, j)
- dLocal := dVal * float32(sc)
- if dLocal == 0 {
- for i := 0; i < 16; i++ {
- lFinal[j*16+i] = 0
- }
- continue
- }
- sub := block[j*16 : (j+1)*16]
- for i := 0; i < 16; i++ {
- q := clampInt(int(math.Round(float64(sub[i]/dLocal))), -4, 3)
- lFinal[j*16+i] = uint8(q + 4)
- }
- }
- // Build hmask and strip high bit
- m := 0
- hm := uint8(1)
- for j := 0; j < QK_K; j++ {
- if lFinal[j] > 3 {
- hmask[m] |= hm
- lFinal[j] -= 4
- }
- m++
- if m == QK_K/8 {
- m = 0
- hm <<= 1
- }
- }
- // Pack QS: four 2-bit lanes per byte
- for nIdx := 0; nIdx < 256; nIdx += 128 {
- for l := 0; l < 32; l++ {
- qs[nIdx/4+l] = lFinal[nIdx+l] |
- (lFinal[nIdx+l+32] << 2) |
- (lFinal[nIdx+l+64] << 4) |
- (lFinal[nIdx+l+96] << 6)
- }
- }
- // Write d
- dF16 := float32ToFloat16(dVal)
- outBlock[108] = byte(dF16)
- outBlock[109] = byte(dF16 >> 8)
- }
- return out
- }
- // QuantizeQ2K quantizes float32 data to Q2_K format
- // Layout (84 bytes per 256 elements):
- // - Scales (16 bytes): 4-bit scale + 4-bit min per sub-block
- // - QS (64 bytes): packed 2-bit quants
- // - D (2 bytes): float16 super-scale
- // - DMin (2 bytes): float16 super-min-scale
- func QuantizeQ2K(data []float32) []byte {
- n := len(data)
- padding := (QK_K - (n % QK_K)) % QK_K
- if padding > 0 {
- padded := make([]float32, n+padding)
- copy(padded, data)
- data = padded
- }
- nBlocks := len(data) / QK_K
- out := make([]byte, nBlocks*84)
- var scales [16]float32
- var mins [16]float32
- var scaleNib [16]uint8
- var minNib [16]uint8
- var lFinal [256]uint8
- for b := 0; b < nBlocks; b++ {
- block := data[b*QK_K : (b+1)*QK_K]
- outBlock := out[b*84 : (b+1)*84]
- scalesPacked := outBlock[0:16]
- qs := outBlock[16:80]
- var maxScale float32
- var maxMin float32
- // Per-sub-block quant search
- for j := 0; j < 16; j++ {
- sub := block[j*16 : (j+1)*16]
- scale, lTmp, tgtMin := makeQKX2Quants(sub)
- scales[j] = scale
- mins[j] = tgtMin
- if scale > maxScale {
- maxScale = scale
- }
- if tgtMin > maxMin {
- maxMin = tgtMin
- }
- // lTmp unused here; we re-quantize after super-scale
- _ = lTmp
- }
- var dVal float32
- if maxScale > 0 {
- inv := 15.0 / maxScale
- for j := 0; j < 16; j++ {
- scaleNib[j] = uint8(clampInt(int(math.Round(float64(inv*scales[j]))), 0, 15))
- }
- dVal = maxScale / 15.0
- } else {
- for j := 0; j < 16; j++ {
- scaleNib[j] = 0
- }
- dVal = 0
- }
- var dminVal float32
- if maxMin > 0 {
- invMin := 15.0 / maxMin
- for j := 0; j < 16; j++ {
- minNib[j] = uint8(clampInt(int(math.Round(float64(invMin*mins[j]))), 0, 15))
- }
- dminVal = maxMin / 15.0
- } else {
- for j := 0; j < 16; j++ {
- minNib[j] = 0
- }
- dminVal = 0
- }
- // Pack scales/mins nibbles
- for j := 0; j < 16; j++ {
- scalesPacked[j] = (scaleNib[j] & 0xF) | ((minNib[j] & 0xF) << 4)
- }
- // Re-quantize weights with quantized scales/mins
- for j := 0; j < 16; j++ {
- dl := dVal * float32(scaleNib[j])
- if dl == 0 {
- for i := 0; i < 16; i++ {
- lFinal[j*16+i] = 0
- }
- continue
- }
- dm := dminVal * float32(minNib[j])
- sub := block[j*16 : (j+1)*16]
- for i := 0; i < 16; i++ {
- q := clampInt(nearestIntFloat32((sub[i]+dm)/dl), 0, 3)
- lFinal[j*16+i] = uint8(q)
- }
- }
- // Pack QS
- for nIdx := 0; nIdx < 256; nIdx += 128 {
- for l := 0; l < 32; l++ {
- qs[nIdx/4+l] = lFinal[nIdx+l] |
- (lFinal[nIdx+l+32] << 2) |
- (lFinal[nIdx+l+64] << 4) |
- (lFinal[nIdx+l+96] << 6)
- }
- }
- // Write d and dmin
- dF16 := float32ToFloat16(dVal)
- dminF16 := float32ToFloat16(dminVal)
- outBlock[80] = byte(dF16)
- outBlock[81] = byte(dF16 >> 8)
- outBlock[82] = byte(dminF16)
- outBlock[83] = byte(dminF16 >> 8)
- }
- return out
- }
- // QuantizeQ4K quantizes float32 data to Q4_K format
- // Layout (144 bytes per 256 elements):
- // - D (2 bytes): float16 super-scale
- // - DMin (2 bytes): float16 super-min-scale
- // - Scales (12 bytes): packed 6-bit scales and mins
- // - QS (128 bytes): 256 4-bit quants
- func QuantizeQ4K(data []float32) []byte {
- n := len(data)
- padding := (QK_K - (n % QK_K)) % QK_K
- if padding > 0 {
- padded := make([]float32, n+padding)
- copy(padded, data)
- data = padded
- }
- nBlocks := len(data) / QK_K
- out := make([]byte, nBlocks*144)
- for b := 0; b < nBlocks; b++ {
- superblock := data[b*QK_K : (b+1)*QK_K]
- outBlock := out[b*144 : (b+1)*144]
- // Calculate min/max/scale per 32-element sub-block (8 sub-blocks)
- var sbMin, sbMax, sbScale [8]float32
- var targetMins [8]float32
- for j := 0; j < 8; j++ {
- sub := superblock[j*32 : (j+1)*32]
- minVal := sub[0]
- maxVal := sub[0]
- for _, v := range sub {
- if v < minVal {
- minVal = v
- }
- if v > maxVal {
- maxVal = v
- }
- }
- sbMin[j] = minVal
- sbMax[j] = maxVal
- // Constrain min to be at most 0
- minConstrained := minVal
- if minConstrained > 0 {
- minConstrained = 0
- }
- sbScale[j] = (maxVal - minConstrained) / 15.0
- targetMins[j] = -minConstrained // >= 0
- }
- // Super-block scales
- var maxScaleVal, maxMinVal float32
- for j := 0; j < 8; j++ {
- if sbScale[j] > maxScaleVal {
- maxScaleVal = sbScale[j]
- }
- if targetMins[j] > maxMinVal {
- maxMinVal = targetMins[j]
- }
- }
- dVal := maxScaleVal / 63.0
- dminVal := maxMinVal / 63.0
- // Avoid division by zero
- if dVal == 0 {
- dVal = 1.0
- }
- if dminVal == 0 {
- dminVal = 1.0
- }
- // Quantize scales and mins to 6 bits
- var ls, lm [8]uint8
- for j := 0; j < 8; j++ {
- ls[j] = uint8(clampInt(int(math.Round(float64(sbScale[j]/dVal))), 0, 63))
- lm[j] = uint8(clampInt(int(math.Round(float64(targetMins[j]/dminVal))), 0, 63))
- }
- // Restore zeros
- if maxScaleVal == 0 {
- dVal = 0
- }
- if maxMinVal == 0 {
- dminVal = 0
- }
- // Reconstruct local scales/mins
- var recS, recM [8]float32
- for j := 0; j < 8; j++ {
- recS[j] = float32(ls[j]) * dVal
- recM[j] = float32(lm[j]) * dminVal
- }
- // Quantize weights: w = q * s - m => q = (w + m) / s
- var qVals [256]uint8
- for j := 0; j < 8; j++ {
- s := recS[j]
- m := recM[j]
- if s == 0 {
- s = 1.0
- }
- for i := 0; i < 32; i++ {
- q := int(math.Round(float64((superblock[j*32+i] + m) / s)))
- qVals[j*32+i] = uint8(clampInt(q, 0, 15))
- }
- }
- // Write D and DMin as float16
- dF16 := float32ToFloat16(dVal)
- dminF16 := float32ToFloat16(dminVal)
- outBlock[0] = byte(dF16)
- outBlock[1] = byte(dF16 >> 8)
- outBlock[2] = byte(dminF16)
- outBlock[3] = byte(dminF16 >> 8)
- // Pack scales (12 bytes)
- scales := outBlock[4:16]
- // scales[0..3] = ls[0..3] | (ls[4..7] high 2 bits << 6)
- // scales[4..7] = lm[0..3] | (lm[4..7] high 2 bits << 6)
- // scales[8..11] = (ls[4..7] low 4 bits) | (lm[4..7] low 4 bits << 4)
- for j := 0; j < 4; j++ {
- scales[j] = ls[j] | ((ls[j+4] >> 4) << 6)
- scales[j+4] = lm[j] | ((lm[j+4] >> 4) << 6)
- }
- for j := 0; j < 4; j++ {
- scales[8+j] = (ls[j+4] & 0xF) | ((lm[j+4] & 0xF) << 4)
- }
- // Pack QS (128 bytes): pairs of nibbles
- qs := outBlock[16:144]
- for chunk := 0; chunk < 4; chunk++ {
- base := chunk * 64
- for l := 0; l < 32; l++ {
- low := qVals[base+l]
- high := qVals[base+l+32]
- qs[chunk*32+l] = low | (high << 4)
- }
- }
- }
- return out
- }
- // Helper functions
// clampInt limits v to the range [lo, hi]: values below lo become lo and
// values above hi become hi. The lower bound is checked first.
func clampInt(v, lo, hi int) int {
	switch {
	case v < lo:
		return lo
	case v > hi:
		return hi
	default:
		return v
	}
}
// nearestIntFloat32 matches llama.cpp's nearest_int for float32 inputs.
//
// Adding 2^23 + 2^22 moves the value into a range where a float32 has no
// fractional bits, so the addition itself performs the rounding and the
// integer can be read straight out of the mantissa. Only meaningful while the
// sum stays inside that range (roughly |v| < 2^22).
func nearestIntFloat32(v float32) int {
	const (
		magic    = 12582912.0 // 2^23 + 2^22
		mantMask = 0x007FFFFF
		bias     = 0x00400000
	)
	shifted := v + magic
	return int(math.Float32bits(shifted)&mantMask) - bias
}
// float32ToFloat16 converts a float32 to the bit pattern of the nearest
// IEEE 754 half-precision value, rounding ties to even.
//
//   - Inf maps to Inf, NaN maps to a quiet NaN (0x7E00), sign preserved.
//   - float32 zeros and subnormals map to signed zero.
//   - Magnitudes too large for half precision (including values that round
//     up past 65504) map to signed infinity.
//   - Magnitudes below the smallest half subnormal round to zero or to that
//     subnormal.
func float32ToFloat16(f float32) uint16 {
	bits := math.Float32bits(f)
	sign := uint16((bits >> 16) & 0x8000)
	exp := int((bits >> 23) & 0xFF)
	mant := bits & 0x007FFFFF

	if exp == 0xFF {
		if mant == 0 {
			return sign | 0x7C00 // Inf
		}
		return sign | 0x7E00 // NaN
	}
	if exp == 0 {
		// float32 zero or subnormal: far below anything half can hold.
		return sign
	}

	// Rebias the exponent from 127 to 15.
	newExp := exp - 127 + 15
	if newExp >= 31 {
		return sign | 0x7C00 // overflow to infinity
	}

	if newExp <= 0 {
		if newExp < -10 {
			// Less than half of the smallest subnormal: rounds to zero.
			return sign
		}
		// Subnormal result: make the implicit leading one explicit and shift
		// it into place. A round-up to 0x400 is exactly the smallest normal
		// number, so no special case is needed.
		mant |= 0x00800000
		return sign | uint16(roundShiftEven(mant, uint(14-newExp)))
	}

	// Normal result: keep the top 10 mantissa bits. The fields are combined
	// with an addition so a mantissa that rounds up to 0x400 carries into the
	// exponent (and becomes infinity when the exponent was already 30).
	return sign | (uint16(newExp)<<10 + uint16(roundShiftEven(mant, 13)))
}

// roundShiftEven returns v >> shift rounded to nearest, with ties going to
// the even result. shift must be at least 1.
func roundShiftEven(v uint32, shift uint) uint32 {
	q := v >> shift
	rem := v & (uint32(1)<<shift - 1)
	half := uint32(1) << (shift - 1)
	if rem > half || (rem == half && q&1 == 1) {
		q++
	}
	return q
}
- // makeQ3Scale computes the RMSE-optimized scale for a 16-value block (Q3_K).
- // Returns the scale; quantized values are not needed for the first pass.
- func makeQ3Scale(x []float32) float32 {
- const nmax = 4.0
- const eps = 1e-15
- // Find max absolute and the value achieving it (with sign)
- var amax float64
- var maxVal float64
- for _, v := range x {
- av := math.Abs(float64(v))
- if av > amax {
- amax = av
- maxVal = float64(v)
- }
- }
- if amax < eps {
- return 0
- }
- iscale := -nmax / maxVal
- var L [16]float64
- var sumlx float64
- var suml2 float64
- for i, v := range x {
- l := float64(clampInt(int(math.Round(iscale*float64(v))), -4, 3))
- L[i] = l
- w := float64(v) * float64(v)
- sumlx += w * float64(v) * l
- suml2 += w * l * l
- }
- for iter := 0; iter < 5; iter++ {
- changed := 0
- for i, v := range x {
- w := float64(v) * float64(v)
- slx := sumlx - w*float64(v)*L[i]
- if slx > 0 {
- sl2 := suml2 - w*L[i]*L[i]
- newL := float64(clampInt(int(math.Round(float64(v)*sl2/slx)), -4, 3))
- if newL != L[i] {
- slxNew := slx + w*float64(v)*newL
- sl2New := sl2 + w*newL*newL
- if sl2New > 0 && slxNew*slxNew*suml2 > sumlx*sumlx*sl2New {
- L[i] = newL
- sumlx = slxNew
- suml2 = sl2New
- changed++
- }
- }
- }
- }
- if changed == 0 {
- break
- }
- }
- if suml2 == 0 {
- return 0
- }
- return float32(sumlx / suml2)
- }
// packQ3Scales packs sixteen 6-bit values into dst (llama.cpp layout).
//
// dst is cleared first and must hold at least 12 bytes. Bytes 0-7 carry the
// low nibbles: value j in the low half of byte j for j < 8, value j in the
// high half of byte j-8 otherwise. Bytes 8-11 carry the top two bits: value j
// occupies bits 2*(j/4) and up of byte 8 + j%4.
func packQ3Scales(ls [16]uint8, dst []byte) {
	for i := range dst {
		dst[i] = 0
	}
	for j, code := range ls {
		nibble := code & 0xF
		top := (code >> 4) & 0x3
		if j >= 8 {
			dst[j-8] |= nibble << 4
		} else {
			dst[j] = nibble
		}
		dst[8+(j%4)] |= top << (2 * (j / 4))
	}
}
// unpackQ3Scale reads entry idx back out of a 12-byte table written by
// packQ3Scales and removes the +32 bias, giving a value in [-32, 31].
// idx is expected to be in [0, 15].
func unpackQ3Scale(packed []byte, idx int) int8 {
	var nibble uint8
	switch {
	case idx < 8:
		nibble = packed[idx] & 0xF
	default:
		nibble = packed[idx-8] >> 4
	}
	top := (packed[8+(idx%4)] >> (2 * (idx / 4))) & 0x3
	return int8(nibble|(top<<4)) - 32
}
- // makeQKX2Quants implements the search used by Q2_K (port of llama.cpp/makarna python).
- // Returns scale, quantized values (unused by caller), and targetMin (-min).
- func makeQKX2Quants(x []float32) (float32, [16]uint8, float32) {
- const nmax = 3.0
- const rmin = -0.5
- const rdelta = 0.1
- const nstep = 15
- var lBest [16]uint8
- minVal := x[0]
- maxVal := x[0]
- sumW := float32(math.Abs(float64(x[0])))
- sumX := sumW * x[0]
- for i := 1; i < 16; i++ {
- v := x[i]
- if v < minVal {
- minVal = v
- }
- if v > maxVal {
- maxVal = v
- }
- w := float32(math.Abs(float64(v)))
- sumW += w
- sumX += w * v
- }
- if minVal > 0 {
- minVal = 0
- }
- if maxVal == minVal {
- return 0, lBest, -minVal
- }
- iscale := float32(nmax) / (maxVal - minVal)
- scale := 1 / iscale
- var L [16]uint8
- bestErr := float32(0)
- for i := 0; i < 16; i++ {
- l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 3)
- L[i] = uint8(l)
- diff := scale*float32(l) + minVal - x[i]
- if diff < 0 {
- diff = -diff
- }
- w := float32(math.Abs(float64(x[i])))
- bestErr += w * diff
- }
- bestScale := scale
- bestMin := minVal
- copy(lBest[:], L[:])
- for isIdx := 0; isIdx <= nstep; isIdx++ {
- iscale = (float32(rmin) + float32(rdelta)*float32(isIdx) + float32(nmax)) / (maxVal - minVal)
- var Laux [16]uint8
- var sumL, sumL2, sumXL float32
- for i := 0; i < 16; i++ {
- l := clampInt(nearestIntFloat32(iscale*(x[i]-minVal)), 0, 3)
- Laux[i] = uint8(l)
- lf := float32(l)
- w := float32(math.Abs(float64(x[i])))
- sumL += w * lf
- sumL2 += w * lf * lf
- sumXL += w * lf * x[i]
- }
- D := sumW*sumL2 - sumL*sumL
- if D > 0 {
- thisScale := (sumW*sumXL - sumX*sumL) / D
- thisMin := (sumL2*sumX - sumL*sumXL) / D
- if thisMin > 0 {
- thisMin = 0
- thisScale = sumXL / sumL2
- }
- curErr := float32(0)
- for i := 0; i < 16; i++ {
- diff := thisScale*float32(Laux[i]) + thisMin - x[i]
- if diff < 0 {
- diff = -diff
- }
- w := float32(math.Abs(float64(x[i])))
- curErr += w * diff
- }
- if curErr < bestErr {
- copy(lBest[:], Laux[:])
- bestErr = curErr
- bestScale = thisScale
- bestMin = thisMin
- }
- }
- }
- return bestScale, lBest, -bestMin
- }
|