| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480 |
- package tensor
- import (
- "math"
- "sync"
- "unsafe"
- )
// QK_K is the number of weights encoded by one K-quant super-block; every
// block type in this file (Q2_K through Q8_K) covers exactly QK_K weights.
const (
	QK_K = 256
)
// unpackQ3Scale returns the idx-th (0..15) signed 6-bit sub-block scale of a
// Q3_K block, using the llama.cpp packing of the 12 scale bytes:
//   - bytes 0..7 carry the low nibbles: scales 0..7 live in the low half of
//     packed[idx], scales 8..15 in the high half of packed[idx-8];
//   - bytes 8..11 carry the top two bits: scale idx takes bit pair
//     2*(idx/4) of packed[8+idx%4].
//
// The unsigned 6-bit value is re-centred to [-32, 31] by subtracting 32.
func unpackQ3Scale(packed []byte, idx int) int8 {
	var low uint8
	switch {
	case idx < 8:
		low = packed[idx] & 0x0F
	default:
		low = packed[idx-8] >> 4
	}
	high := (packed[8+idx%4] >> (2 * (idx / 4))) & 0x03
	return int8(low|high<<4) - 32
}
type (
	// BlockQ4_K holds 256 weights quantized to 4 bits, organised as 8
	// sub-blocks of 32 weights that each carry a 6-bit scale and a 6-bit min.
	//
	// Layout (144 bytes):
	//   - D      (2 bytes):   float16 super-block scale
	//   - DMin   (2 bytes):   float16 super-block min scale
	//   - Scales (12 bytes):  8 6-bit scales and 8 6-bit mins, packed
	//   - QS     (128 bytes): 256 4-bit quants, two per byte
	BlockQ4_K struct {
		D      uint16
		DMin   uint16
		Scales [12]uint8
		QS     [128]uint8
	}

	// BlockQ5_K holds 256 weights quantized to 5 bits. It uses the same
	// packed scale/min encoding as BlockQ4_K and adds one high bit per weight.
	//
	// Layout (176 bytes):
	//   - D      (2 bytes):   float16 super-block scale
	//   - DMin   (2 bytes):   float16 super-block min scale
	//   - Scales (12 bytes):  8 6-bit scales and 8 6-bit mins, packed
	//   - QH     (32 bytes):  fifth (high) bit of every quant
	//   - QS     (128 bytes): low 4 bits of every quant, two per byte
	BlockQ5_K struct {
		D      uint16
		DMin   uint16
		Scales [12]uint8
		QH     [32]uint8
		QS     [128]uint8
	}
)
// getScaleMinK4 decodes the j-th (0..7) 6-bit scale d and 6-bit min m from
// the 12-byte packed K-quant scale array q, mirroring llama.cpp's
// get_scale_min_k4. Entries 0..3 are the low six bits of q[j] and q[j+4];
// entries 4..7 take their low nibble from q[j+4] and their top two bits from
// the spare high bits of q[j-4] (scale) and q[j] (min).
func getScaleMinK4(j int, q *[12]uint8) (d uint8, m uint8) {
	if j >= 4 {
		packed := q[j+4]
		d = packed&0x0F | (q[j-4]>>6)<<4
		m = packed>>4 | (q[j]>>6)<<4
		return d, m
	}
	return q[j] & 0x3F, q[j+4] & 0x3F
}
// FP16ToFP32 widens an IEEE 754 half-precision value, given as its raw bit
// pattern, to float32. Every half value is exactly representable as float32:
// normals are re-biased, subnormals become mant * 2^-24, ±0 and ±Inf keep
// their sign, and NaNs are returned with the quiet bit set and the payload
// preserved.
func FP16ToFP32(n uint16) float32 {
	sign := uint32(n>>15) << 31
	exp := uint32(n>>10) & 0x1F
	mant := uint32(n) & 0x3FF

	switch exp {
	case 0:
		if mant == 0 {
			return math.Float32frombits(sign) // signed zero
		}
		// Subnormal half: mant * 2^-24, an exact product in float32.
		v := float32(mant) * 0x1p-24
		if sign != 0 {
			return -v
		}
		return v
	case 0x1F:
		if mant == 0 {
			return math.Float32frombits(sign | 0x7F800000) // ±Inf
		}
		return math.Float32frombits(sign | 0x7FC00000 | mant<<13) // quiet NaN
	default:
		// Re-bias the exponent from 15 (half) to 127 (single): +112.
		return math.Float32frombits(sign | (exp+112)<<23 | mant<<13)
	}
}
- // DequantizeQ4_K dequantizes a single Q4_K block into 256 floats
- func DequantizeQ4_K(b *BlockQ4_K, out []float32) {
- if dequantQ4KSimd(b, out) {
- return
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- var sc [8]uint8
- var m [8]uint8
- for j := 0; j < 4; j++ {
- sc[j] = b.Scales[j] & 63
- m[j] = b.Scales[j+4] & 63
- }
- for j := 4; j < 8; j++ {
- sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
- m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
- }
- outPtr := 0
- qsPtr := 0
- for i := 0; i < 8; i += 2 {
- d1 := d * float32(sc[i])
- m1 := dmin * float32(m[i])
- d2 := d * float32(sc[i+1])
- m2 := dmin * float32(m[i+1])
- for l := 0; l < 32; l++ {
- val := b.QS[qsPtr+l]
- v1 := val & 0xF
- v2 := val >> 4
- out[outPtr] = float32(v1)*d1 - m1
- out[outPtr+32] = float32(v2)*d2 - m2
- outPtr++
- }
- outPtr += 32
- qsPtr += 32
- }
- }
- func DotQ4_K(b *BlockQ4_K, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ4_K: mismatched slice length")
- }
- if sum, ok := dotQ4KSimd(b, x); ok {
- return sum
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- var sc [8]uint8
- var m [8]uint8
- for j := 0; j < 4; j++ {
- sc[j] = b.Scales[j] & 63
- m[j] = b.Scales[j+4] & 63
- }
- for j := 4; j < 8; j++ {
- sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
- m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
- }
- var sum float32
- outPtr := 0
- qsPtr := 0
- for i := 0; i < 8; i += 2 {
- d1 := d * float32(sc[i])
- m1 := dmin * float32(m[i])
- d2 := d * float32(sc[i+1])
- m2 := dmin * float32(m[i+1])
- for l := 0; l < 32; l++ {
- val := b.QS[qsPtr+l]
- v1 := val & 0xF
- v2 := val >> 4
- sum += x[outPtr] * (float32(v1)*d1 - m1)
- sum += x[outPtr+32] * (float32(v2)*d2 - m2)
- outPtr++
- }
- outPtr += 32
- qsPtr += 32
- }
- return sum
- }
- func DotQ2_K_Params(b *BlockQ2_K, p *Q2KDotParams, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ2_K_Params: mismatched slice length")
- }
- if hasAVX2 {
- is := 0
- xIdx := 0
- qOffset := 0
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- _ = n
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := p.DL[is]
- ml := p.ML[is]
- is++
- sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset], &x[xIdx], dl, ml, shift)
- xIdx += 16
- dl = p.DL[is]
- ml = p.ML[is]
- is++
- sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], &x[xIdx], dl, ml, shift)
- xIdx += 16
- shift += 2
- }
- qOffset += 32
- }
- return sum
- }
- q := b.QS[:]
- is := 0
- outIdx := 0
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := p.DL[is]
- ml := p.ML[is]
- is++
- for l := 0; l < 16; l++ {
- val := float32(int8((q[l] >> shift) & 3))
- sum += x[outIdx] * (dl*val - ml)
- outIdx++
- }
- dl = p.DL[is]
- ml = p.ML[is]
- is++
- for l := 0; l < 16; l++ {
- val := float32(int8((q[l+16] >> shift) & 3))
- sum += x[outIdx] * (dl*val - ml)
- outIdx++
- }
- shift += 2
- }
- q = q[32:]
- }
- return sum
- }
- func DotQ2KTile8(sums *[8]float32, w []BlockQ2_K, wp []Q2KDotParams, base int, stride int, x *float32, n int) {
- if n <= 0 {
- return
- }
- xp := unsafe.Pointer(x)
- if hasAVX2 {
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- is := 0
- xIdx := 0
- qOffset := 0
- for nn := 0; nn < QK_K; nn += 128 {
- _ = nn
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := p.DL[is]
- ml := p.ML[is]
- is++
- xSeg := (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
- sums[t] += dotQ2KInnerAVX2Fused(&b.QS[qOffset], xSeg, dl, ml, shift)
- xIdx += 16
- dl = p.DL[is]
- ml = p.ML[is]
- is++
- xSeg = (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
- sums[t] += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], xSeg, dl, ml, shift)
- xIdx += 16
- shift += 2
- }
- qOffset += 32
- }
- }
- return
- }
- xSlice := unsafe.Slice(x, QK_K)
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- sums[t] += DotQ2_K_Params(b, p, xSlice)
- }
- }
- func DotQ8KTile8(sums *[8]float32, w []BlockQ8_K, base int, stride int, x *float32, n int) {
- if n <= 0 {
- return
- }
- if hasAVX512 {
- for t := 0; t < n; t++ {
- idx := base + t*stride
- sums[t] += dotQ8KAVX512(&w[idx], x)
- }
- return
- }
- if hasAVX2 {
- for t := 0; t < n; t++ {
- idx := base + t*stride
- sums[t] += dotQ8KAVX2(&w[idx], x)
- }
- return
- }
- xSlice := unsafe.Slice(x, QK_K)
- for t := 0; t < n; t++ {
- idx := base + t*stride
- sums[t] += DotQ8_K(&w[idx], xSlice)
- }
- }
- func DotQ3_K_Params(b *BlockQ3_K, p *Q3KDotParams, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ3_K_Params: mismatched slice length")
- }
- if hasAVX2 {
- q := b.QS[:]
- hm := b.HMask[:]
- is := 0
- xIdx := 0
- m := uint8(1)
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- for j := 0; j < 4; j++ {
- dl := p.S[is]
- sum += dotQ3KInnerAVX2Fused(&q[0], &hm[0], &x[xIdx], dl, m, uint(j*2))
- xIdx += 16
- is++
- dl = p.S[is]
- sum += dotQ3KInnerAVX2Fused(&q[16], &hm[16], &x[xIdx], dl, m, uint(j*2))
- xIdx += 16
- is++
- m <<= 1
- }
- q = q[32:]
- }
- return sum
- }
- q := b.QS[:]
- hm := b.HMask[:]
- outIdx := 0
- is := 0
- m := uint8(1)
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := p.S[is]
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l] >> shift) & 0x3)
- if hm[l]&m == 0 {
- qv -= 4
- }
- sum += x[outIdx] * (dl * float32(qv))
- outIdx++
- }
- dl = p.S[is]
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l+16] >> shift) & 0x3)
- if hm[l+16]&m == 0 {
- qv -= 4
- }
- sum += x[outIdx] * (dl * float32(qv))
- outIdx++
- }
- shift += 2
- m <<= 1
- }
- q = q[32:]
- }
- return sum
- }
- func DotQ3KTile8(sums *[8]float32, w []BlockQ3_K, wp []Q3KDotParams, base int, stride int, x *float32, n int) {
- if n <= 0 {
- return
- }
- xp := unsafe.Pointer(x)
- if hasAVX2 {
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- q := b.QS[:]
- hm := b.HMask[:]
- is := 0
- xIdx := 0
- m := uint8(1)
- for nn := 0; nn < QK_K; nn += 128 {
- for j := 0; j < 4; j++ {
- dl := p.S[is]
- xSeg := (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
- sums[t] += dotQ3KInnerAVX2Fused(&q[0], &hm[0], xSeg, dl, m, uint(j*2))
- xIdx += 16
- is++
- dl = p.S[is]
- xSeg = (*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
- sums[t] += dotQ3KInnerAVX2Fused(&q[16], &hm[16], xSeg, dl, m, uint(j*2))
- xIdx += 16
- is++
- m <<= 1
- }
- q = q[32:]
- }
- }
- return
- }
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- xIdx := 0
- q := b.QS[:]
- hm := b.HMask[:]
- is := 0
- m := uint8(1)
- for nn := 0; nn < QK_K; nn += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := p.S[is]
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l] >> shift) & 0x3)
- if hm[l]&m == 0 {
- qv -= 4
- }
- x0 := *(*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
- sums[t] += x0 * (dl * float32(qv))
- xIdx++
- }
- dl = p.S[is]
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l+16] >> shift) & 0x3)
- if hm[l+16]&m == 0 {
- qv -= 4
- }
- x0 := *(*float32)(unsafe.Add(xp, uintptr(xIdx)*4))
- sums[t] += x0 * (dl * float32(qv))
- xIdx++
- }
- shift += 2
- m <<= 1
- }
- q = q[32:]
- }
- }
- }
- func DotQ6_K_Params(b *BlockQ6_K, p *Q6KDotParams, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ6_K_Params: mismatched slice length")
- }
- if hasAVX512 {
- sum := dotQ6KInnerAVX512(&b.QL[0], &b.QH[0], &p.S[0], &x[0])
- sum += dotQ6KInnerAVX512(&b.QL[64], &b.QH[32], &p.S[8], &x[128])
- return sum
- }
- if hasAVX2 {
- sum := dotQ6KInnerAVX2(&b.QL[0], &b.QH[0], &p.S[0], &x[0])
- sum += dotQ6KInnerAVX2(&b.QL[64], &b.QH[32], &p.S[8], &x[128])
- return sum
- }
- qlPtr := 0
- qhPtr := 0
- outPtr := 0
- var sum float32
- for half := 0; half < 2; half++ {
- scBase := half * 8
- for l := 0; l < 32; l++ {
- is := l / 16
- ql0 := b.QL[qlPtr+l]
- ql32 := b.QL[qlPtr+l+32]
- qh := b.QH[qhPtr+l]
- q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
- q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
- q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
- q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
- s0 := p.S[scBase+is+0]
- s2 := p.S[scBase+is+2]
- s4 := p.S[scBase+is+4]
- s6 := p.S[scBase+is+6]
- base := outPtr + l
- sum += x[base+0] * s0 * float32(q1)
- sum += x[base+32] * s2 * float32(q2)
- sum += x[base+64] * s4 * float32(q3)
- sum += x[base+96] * s6 * float32(q4)
- }
- outPtr += 128
- qlPtr += 64
- qhPtr += 32
- }
- return sum
- }
- func DotQ6KTile8(sums *[8]float32, w []BlockQ6_K, wp []Q6KDotParams, base int, stride int, x *float32, n int) {
- if n <= 0 {
- return
- }
- xp := unsafe.Pointer(x)
- if hasAVX512 {
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- sums[t] += dotQ6KInnerAVX512(&b.QL[0], &b.QH[0], &p.S[0], x)
- x128 := (*float32)(unsafe.Add(xp, 128*4))
- sums[t] += dotQ6KInnerAVX512(&b.QL[64], &b.QH[32], &p.S[8], x128)
- }
- return
- }
- if hasAVX2 {
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- sums[t] += dotQ6KInnerAVX2(&b.QL[0], &b.QH[0], &p.S[0], x)
- x128 := (*float32)(unsafe.Add(xp, 128*4))
- sums[t] += dotQ6KInnerAVX2(&b.QL[64], &b.QH[32], &p.S[8], x128)
- }
- return
- }
- for half := 0; half < 2; half++ {
- xBase := unsafe.Add(xp, uintptr(half*128)*4)
- qlPtr := half * 64
- qhPtr := half * 32
- scBase := half * 8
- for l := 0; l < 32; l++ {
- x0 := *(*float32)(unsafe.Add(xBase, uintptr(l)*4))
- x1 := *(*float32)(unsafe.Add(xBase, uintptr(32+l)*4))
- x2 := *(*float32)(unsafe.Add(xBase, uintptr(64+l)*4))
- x3 := *(*float32)(unsafe.Add(xBase, uintptr(96+l)*4))
- is := l / 16
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- ql0 := b.QL[qlPtr+l]
- ql32 := b.QL[qlPtr+l+32]
- qh := b.QH[qhPtr+l]
- q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
- q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
- q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
- q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
- s0 := p.S[scBase+is+0]
- s2 := p.S[scBase+is+2]
- s4 := p.S[scBase+is+4]
- s6 := p.S[scBase+is+6]
- sums[t] += x0*s0*float32(q1) + x1*s2*float32(q2) + x2*s4*float32(q3) + x3*s6*float32(q4)
- }
- }
- }
- }
type (
	// Q5KDotParams holds the fully scaled factors of one BlockQ5_K. In 64-weight
	// segment seg, the first 32 weights use D1[seg]/M1[seg] and the next 32
	// use D2[seg]/M2[seg], with weight = D*q - M.
	Q5KDotParams struct {
		D1 [4]float32
		M1 [4]float32
		D2 [4]float32
		M2 [4]float32
	}

	// q5kParamsKey identifies a []BlockQ5_K by the address of its first element
	// and its length.
	q5kParamsKey struct {
		p unsafe.Pointer
		n int
	}
)

// q5kDotParamsCache maps a q5kParamsKey to the []Q5KDotParams built for it.
var q5kDotParamsCache sync.Map
- func GetQ5KDotParams(blocks []BlockQ5_K) []Q5KDotParams {
- if len(blocks) == 0 {
- return nil
- }
- key := q5kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
- if v, ok := q5kDotParamsCache.Load(key); ok {
- return v.([]Q5KDotParams)
- }
- params := make([]Q5KDotParams, len(blocks))
- for bi := range blocks {
- b := &blocks[bi]
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- var p Q5KDotParams
- seg := 0
- is := 0
- for j := 0; j < QK_K; j += 64 {
- sc, m := getScaleMinK4(is+0, &b.Scales)
- p.D1[seg] = d * float32(sc)
- p.M1[seg] = dmin * float32(m)
- sc, m = getScaleMinK4(is+1, &b.Scales)
- p.D2[seg] = d * float32(sc)
- p.M2[seg] = dmin * float32(m)
- seg++
- is += 2
- }
- params[bi] = p
- }
- q5kDotParamsCache.Store(key, params)
- return params
- }
- func DotQ5_K_Params(b *BlockQ5_K, p *Q5KDotParams, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ5_K_Params: mismatched slice length")
- }
- var sum float32
- qsPtr := 0
- qh := b.QH[:]
- for seg := 0; seg < 4; seg++ {
- d1, m1 := p.D1[seg], p.M1[seg]
- d2, m2 := p.D2[seg], p.M2[seg]
- outPtr := seg * 64
- shift1 := uint(2 * seg)
- shift2 := uint(2*seg + 1)
- if hasAVX512 {
- sum += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], &x[outPtr], d1, m1, d2, m2, shift1, shift2)
- qsPtr += 32
- continue
- }
- if hasAVX2 {
- sum += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], &x[outPtr], d1, m1, d2, m2, shift1, shift2)
- qsPtr += 32
- continue
- }
- u1 := uint8(1 << shift1)
- u2 := uint8(1 << shift2)
- for l := 0; l < 32; l++ {
- v := int(b.QS[qsPtr+l] & 0xF)
- if (qh[l] & u1) != 0 {
- v += 16
- }
- sum += x[outPtr+l] * (d1*float32(v) - m1)
- }
- for l := 0; l < 32; l++ {
- v := int(b.QS[qsPtr+l] >> 4)
- if (qh[l] & u2) != 0 {
- v += 16
- }
- sum += x[outPtr+32+l] * (d2*float32(v) - m2)
- }
- qsPtr += 32
- }
- return sum
- }
- func DotQ5_K_ParamsPtr(b *BlockQ5_K, p *Q5KDotParams, x *float32) float32 {
- var sum float32
- xp := unsafe.Pointer(x)
- qsPtr := 0
- qh := b.QH[:]
- for seg := 0; seg < 4; seg++ {
- d1, m1 := p.D1[seg], p.M1[seg]
- d2, m2 := p.D2[seg], p.M2[seg]
- outPtr := seg * 64
- shift1 := uint(2 * seg)
- shift2 := uint(2*seg + 1)
- if hasAVX512 {
- xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
- sum += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
- qsPtr += 32
- continue
- }
- if hasAVX2 {
- xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
- sum += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
- qsPtr += 32
- continue
- }
- u1 := uint8(1 << shift1)
- u2 := uint8(1 << shift2)
- for l := 0; l < 32; l++ {
- v := int(b.QS[qsPtr+l] & 0xF)
- if (qh[l] & u1) != 0 {
- v += 16
- }
- x0 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+l)*4))
- sum += x0 * (d1*float32(v) - m1)
- }
- for l := 0; l < 32; l++ {
- v := int(b.QS[qsPtr+l] >> 4)
- if (qh[l] & u2) != 0 {
- v += 16
- }
- x1 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+32+l)*4))
- sum += x1 * (d2*float32(v) - m2)
- }
- qsPtr += 32
- }
- return sum
- }
- func DotQ5KTile8(sums *[8]float32, w []BlockQ5_K, wp []Q5KDotParams, base int, stride int, x *float32, n int) {
- if n <= 0 {
- return
- }
- xp := unsafe.Pointer(x)
- for seg := 0; seg < 4; seg++ {
- xSeg := (*float32)(unsafe.Add(xp, uintptr(seg*64)*4))
- qsPtr := seg * 32
- shift1 := uint(2 * seg)
- shift2 := uint(2*seg + 1)
- u1 := uint8(1 << shift1)
- u2 := uint8(1 << shift2)
- xsp := unsafe.Pointer(xSeg)
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- d1, m1 := p.D1[seg], p.M1[seg]
- d2, m2 := p.D2[seg], p.M2[seg]
- if hasAVX512 {
- sums[t] += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
- continue
- }
- if hasAVX2 {
- sums[t] += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2)
- continue
- }
- qh := b.QH[:]
- for l := 0; l < 32; l++ {
- v0 := int(b.QS[qsPtr+l] & 0xF)
- if (qh[l] & u1) != 0 {
- v0 += 16
- }
- v1 := int(b.QS[qsPtr+l] >> 4)
- if (qh[l] & u2) != 0 {
- v1 += 16
- }
- x0 := *(*float32)(unsafe.Add(xsp, uintptr(l)*4))
- x1 := *(*float32)(unsafe.Add(xsp, uintptr(32+l)*4))
- sums[t] += x0*(d1*float32(v0)-m1) + x1*(d2*float32(v1)-m2)
- }
- }
- }
- }
- func DequantizeQ5_K(b *BlockQ5_K, out []float32) {
- d := FP16ToFP32(b.D)
- min := FP16ToFP32(b.DMin)
- outIdx := 0
- ql := b.QS[:]
- qh := b.QH[:]
- is := 0
- u1 := uint8(1)
- u2 := uint8(2)
- for j := 0; j < QK_K; j += 64 {
- sc, m := getScaleMinK4(is+0, &b.Scales)
- d1 := d * float32(sc)
- m1 := min * float32(m)
- sc, m = getScaleMinK4(is+1, &b.Scales)
- d2 := d * float32(sc)
- m2 := min * float32(m)
- for l := 0; l < 32; l++ {
- v := int(ql[l] & 0xF)
- if (qh[l] & u1) != 0 {
- v += 16
- }
- out[outIdx] = d1*float32(v) - m1
- outIdx++
- }
- for l := 0; l < 32; l++ {
- v := int(ql[l] >> 4)
- if (qh[l] & u2) != 0 {
- v += 16
- }
- out[outIdx] = d2*float32(v) - m2
- outIdx++
- }
- ql = ql[32:]
- is += 2
- u1 <<= 2
- u2 <<= 2
- }
- }
- func DotQ5_K(b *BlockQ5_K, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ5_K: mismatched slice length")
- }
- d := FP16ToFP32(b.D)
- min := FP16ToFP32(b.DMin)
- ql := b.QS[:]
- qh := b.QH[:]
- is := 0
- u1 := uint8(1)
- u2 := uint8(2)
- var sum float32
- outIdx := 0
- for j := 0; j < QK_K; j += 64 {
- sc, m := getScaleMinK4(is+0, &b.Scales)
- d1 := d * float32(sc)
- m1 := min * float32(m)
- sc, m = getScaleMinK4(is+1, &b.Scales)
- d2 := d * float32(sc)
- m2 := min * float32(m)
- for l := 0; l < 32; l++ {
- v := int(ql[l] & 0xF)
- if (qh[l] & u1) != 0 {
- v += 16
- }
- sum += x[outIdx] * (d1*float32(v) - m1)
- outIdx++
- }
- for l := 0; l < 32; l++ {
- v := int(ql[l] >> 4)
- if (qh[l] & u2) != 0 {
- v += 16
- }
- sum += x[outIdx] * (d2*float32(v) - m2)
- outIdx++
- }
- ql = ql[32:]
- is += 2
- u1 <<= 2
- u2 <<= 2
- }
- return sum
- }
type (
	// Q4KDotParams holds the fully scaled factors of one BlockQ4_K. In 64-weight
	// segment seg, the low-nibble weights use D1[seg]/M1[seg] and the
	// high-nibble weights use D2[seg]/M2[seg], with weight = q*D - M.
	Q4KDotParams struct {
		D1 [4]float32
		M1 [4]float32
		D2 [4]float32
		M2 [4]float32
	}

	// q4kParamsKey identifies a []BlockQ4_K by the address of its first element
	// and its length.
	q4kParamsKey struct {
		p unsafe.Pointer
		n int
	}
)

// q4kDotParamsCache maps a q4kParamsKey to the []Q4KDotParams built for it.
var q4kDotParamsCache sync.Map
- func GetQ4KDotParams(blocks []BlockQ4_K) []Q4KDotParams {
- if len(blocks) == 0 {
- return nil
- }
- key := q4kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
- if v, ok := q4kDotParamsCache.Load(key); ok {
- return v.([]Q4KDotParams)
- }
- params := make([]Q4KDotParams, len(blocks))
- for bi := range blocks {
- b := &blocks[bi]
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- var sc [8]uint8
- var m [8]uint8
- for j := 0; j < 4; j++ {
- sc[j] = b.Scales[j] & 63
- m[j] = b.Scales[j+4] & 63
- }
- for j := 4; j < 8; j++ {
- sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
- m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
- }
- var p Q4KDotParams
- seg := 0
- for i := 0; i < 8; i += 2 {
- p.D1[seg] = d * float32(sc[i])
- p.M1[seg] = dmin * float32(m[i])
- p.D2[seg] = d * float32(sc[i+1])
- p.M2[seg] = dmin * float32(m[i+1])
- seg++
- }
- params[bi] = p
- }
- q4kDotParamsCache.Store(key, params)
- return params
- }
- func DotQ4_K_Params(b *BlockQ4_K, p *Q4KDotParams, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ4_K_Params: mismatched slice length")
- }
- var sum float32
- outPtr := 0
- qsPtr := 0
- for seg := 0; seg < 4; seg++ {
- d1, m1 := p.D1[seg], p.M1[seg]
- d2, m2 := p.D2[seg], p.M2[seg]
- if hasAVX512 {
- sum += dotQ4KInnerAVX512(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
- } else if hasAVX2 {
- sum += dotQ4KInnerAVX2(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
- } else {
- for l := 0; l < 32; l++ {
- val := b.QS[qsPtr+l]
- v1 := val & 0xF
- v2 := val >> 4
- sum += x[outPtr] * (float32(v1)*d1 - m1)
- sum += x[outPtr+32] * (float32(v2)*d2 - m2)
- outPtr++
- }
- outPtr += 32
- qsPtr += 32
- continue
- }
- outPtr += 64
- qsPtr += 32
- }
- return sum
- }
- func DotQ4_K_ParamsPtr(b *BlockQ4_K, p *Q4KDotParams, x *float32) float32 {
- var sum float32
- xp := unsafe.Pointer(x)
- outPtr := 0
- qsPtr := 0
- for seg := 0; seg < 4; seg++ {
- d1, m1 := p.D1[seg], p.M1[seg]
- d2, m2 := p.D2[seg], p.M2[seg]
- xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
- if hasAVX512 {
- sum += dotQ4KInnerAVX512(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
- } else if hasAVX2 {
- sum += dotQ4KInnerAVX2(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
- } else {
- for l := 0; l < 32; l++ {
- val := b.QS[qsPtr+l]
- v1 := val & 0xF
- v2 := val >> 4
- x0 := *(*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
- x1 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+32)*4))
- sum += x0 * (float32(v1)*d1 - m1)
- sum += x1 * (float32(v2)*d2 - m2)
- outPtr++
- }
- outPtr += 32
- qsPtr += 32
- continue
- }
- outPtr += 64
- qsPtr += 32
- }
- return sum
- }
- func DotQ4KTile8(sums *[8]float32, w []BlockQ4_K, wp []Q4KDotParams, base int, stride int, x *float32, n int) {
- if n <= 0 {
- return
- }
- xp := unsafe.Pointer(x)
- outPtr := 0
- qsPtr := 0
- for seg := 0; seg < 4; seg++ {
- xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4))
- for t := 0; t < n; t++ {
- idx := base + t*stride
- b := &w[idx]
- p := &wp[idx]
- d1, m1 := p.D1[seg], p.M1[seg]
- d2, m2 := p.D2[seg], p.M2[seg]
- if hasAVX512 {
- sums[t] += dotQ4KInnerAVX512(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
- } else if hasAVX2 {
- sums[t] += dotQ4KInnerAVX2(&b.QS[qsPtr], xSeg, d1, m1, d2, m2)
- } else {
- xsp := unsafe.Pointer(xSeg)
- for l := 0; l < 32; l++ {
- val := b.QS[qsPtr+l]
- v1 := val & 0xF
- v2 := val >> 4
- x0 := *(*float32)(unsafe.Add(xsp, uintptr(l)*4))
- x1 := *(*float32)(unsafe.Add(xsp, uintptr(32+l)*4))
- sums[t] += x0*(float32(v1)*d1-m1) + x1*(float32(v2)*d2-m2)
- }
- }
- }
- outPtr += 64
- qsPtr += 32
- }
- }
// BlockQ8_K holds 256 weights quantized to signed 8 bits under a single
// float32 scale: weight i = D * QS[i].
//
// Layout (292 bytes):
//   - D     (4 bytes):   float32 scale
//   - QS    (256 bytes): 256 int8 quants
//   - BSums (32 bytes):  16 int16 block sums, intended for dot-product
//     optimization; DequantizeQ8_K and the scalar DotQ8_K do not read them
type BlockQ8_K struct {
	D     float32   // block scale
	QS    [256]int8 // quantized weights
	BSums [16]int16 // precomputed block sums
}
- // DequantizeQ8_K dequantizes a single Q8_K block into 256 floats
- func DequantizeQ8_K(b *BlockQ8_K, out []float32) {
- if dequantQ8KSimd(b, out) {
- return
- }
- d := b.D
- for i := 0; i < 256; i++ {
- out[i] = d * float32(b.QS[i])
- }
- }
- func DotQ8_K(b *BlockQ8_K, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ8_K: mismatched slice length")
- }
- if sum, ok := dotQ8KSimd(b, x); ok {
- return sum
- }
- d := b.D
- var sum float32
- for i := 0; i < 256; i++ {
- sum += x[i] * float32(b.QS[i])
- }
- return d * sum
- }
type (
	// BlockQ2_K holds 256 weights quantized to 2 bits, as 16 sub-blocks of 16
	// weights that each carry a 4-bit scale and a 4-bit min.
	//
	// Layout (84 bytes):
	//   - Scales (16 bytes): one byte per sub-block, low nibble = scale,
	//     high nibble = min
	//   - QS     (64 bytes): 256 2-bit quants, four per byte
	//   - D      (2 bytes):  float16 super-block scale
	//   - DMin   (2 bytes):  float16 super-block min scale
	BlockQ2_K struct {
		Scales [16]uint8
		QS     [64]uint8
		D      uint16
		DMin   uint16
	}

	// Q2KDotParams holds, for each of the 16 sub-blocks of one BlockQ2_K, the
	// fully scaled factor DL[s] and offset ML[s] (weight = DL*q - ML).
	Q2KDotParams struct {
		DL [16]float32
		ML [16]float32
	}

	// q2kParamsKey identifies a []BlockQ2_K by the address of its first element
	// and its length.
	q2kParamsKey struct {
		p unsafe.Pointer
		n int
	}
)

// q2kDotParamsCache maps a q2kParamsKey to the []Q2KDotParams built for it.
var q2kDotParamsCache sync.Map
- func GetQ2KDotParams(blocks []BlockQ2_K) []Q2KDotParams {
- if len(blocks) == 0 {
- return nil
- }
- key := q2kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
- if v, ok := q2kDotParamsCache.Load(key); ok {
- return v.([]Q2KDotParams)
- }
- params := make([]Q2KDotParams, len(blocks))
- for bi := range blocks {
- b := &blocks[bi]
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- var p Q2KDotParams
- pi := 0
- is := 0
- for n := 0; n < QK_K; n += 128 {
- _ = n
- for shift := uint(0); shift < 8; shift += 2 {
- sc := b.Scales[is]
- is++
- p.DL[pi] = d * float32(sc&0xF)
- p.ML[pi] = dmin * float32(sc>>4)
- pi++
- sc = b.Scales[is]
- is++
- p.DL[pi] = d * float32(sc&0xF)
- p.ML[pi] = dmin * float32(sc>>4)
- pi++
- }
- }
- params[bi] = p
- }
- q2kDotParamsCache.Store(key, params)
- return params
- }
- // DequantizeQ2_K dequantizes a single Q2_K block into 256 floats
- // Mirrors llama.cpp dequantize_row_q2_K
- func DequantizeQ2_K(b *BlockQ2_K, out []float32) {
- if dequantQ2KSimd(b, out) {
- return
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- q := b.QS[:]
- is := 0
- outIdx := 0
- for n := 0; n < QK_K; n += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- sc := b.Scales[is]
- is++
- dl := d * float32(sc&0xF)
- ml := dmin * float32(sc>>4)
- for l := 0; l < 16; l++ {
- val := int8((q[l] >> shift) & 3)
- out[outIdx] = dl*float32(val) - ml
- outIdx++
- }
- sc = b.Scales[is]
- is++
- dl = d * float32(sc&0xF)
- ml = dmin * float32(sc>>4)
- for l := 0; l < 16; l++ {
- val := int8((q[l+16] >> shift) & 3)
- out[outIdx] = dl*float32(val) - ml
- outIdx++
- }
- shift += 2
- }
- q = q[32:]
- }
- }
- func DotQ2_K(b *BlockQ2_K, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ2_K: mismatched slice length")
- }
- if sum, ok := dotQ2KSimd(b, x); ok {
- return sum
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- q := b.QS[:]
- is := 0
- outIdx := 0
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- sc := b.Scales[is]
- is++
- dl := d * float32(sc&0xF)
- ml := dmin * float32(sc>>4)
- for l := 0; l < 16; l++ {
- val := float32(int8((q[l] >> shift) & 3))
- sum += x[outIdx] * (dl*val - ml)
- outIdx++
- }
- sc = b.Scales[is]
- is++
- dl = d * float32(sc&0xF)
- ml = dmin * float32(sc>>4)
- for l := 0; l < 16; l++ {
- val := float32(int8((q[l+16] >> shift) & 3))
- sum += x[outIdx] * (dl*val - ml)
- outIdx++
- }
- shift += 2
- }
- q = q[32:]
- }
- return sum
- }
type (
	// BlockQ3_K holds 256 weights quantized to 3 bits, as 16 sub-blocks of 16
	// weights that each carry a signed 6-bit scale.
	//
	// Layout (110 bytes):
	//   - HMask  (32 bytes): high (third) bit of each quant; a cleared bit
	//     subtracts 4 from the quant
	//   - QS     (64 bytes): low 2 bits of each quant, four per byte
	//   - Scales (12 bytes): 16 6-bit scales, packed (see unpackQ3Scale)
	//   - D      (2 bytes):  float16 super-block scale
	BlockQ3_K struct {
		HMask  [32]uint8
		QS     [64]uint8
		Scales [12]uint8
		D      uint16
	}

	// Q3KDotParams holds the fully scaled factor of each of the 16 sub-blocks
	// of one BlockQ3_K.
	Q3KDotParams struct {
		S [16]float32
	}

	// q3kParamsKey identifies a []BlockQ3_K by the address of its first element
	// and its length.
	q3kParamsKey struct {
		p unsafe.Pointer
		n int
	}
)

// q3kDotParamsCache maps a q3kParamsKey to the []Q3KDotParams built for it.
var q3kDotParamsCache sync.Map
- func GetQ3KDotParams(blocks []BlockQ3_K) []Q3KDotParams {
- if len(blocks) == 0 {
- return nil
- }
- key := q3kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
- if v, ok := q3kDotParamsCache.Load(key); ok {
- return v.([]Q3KDotParams)
- }
- params := make([]Q3KDotParams, len(blocks))
- for bi := range blocks {
- b := &blocks[bi]
- d := FP16ToFP32(b.D)
- var p Q3KDotParams
- for i := 0; i < 16; i++ {
- p.S[i] = d * float32(unpackQ3Scale(b.Scales[:], i))
- }
- params[bi] = p
- }
- q3kDotParamsCache.Store(key, params)
- return params
- }
- // DequantizeQ3_K dequantizes a single Q3_K block into 256 floats
- func DequantizeQ3_K(b *BlockQ3_K, out []float32) {
- if dequantQ3KSimd(b, out) {
- return
- }
- d := FP16ToFP32(b.D)
- var scales [16]int8
- for i := 0; i < 16; i++ {
- scales[i] = unpackQ3Scale(b.Scales[:], i)
- }
- q := b.QS[:]
- hm := b.HMask[:]
- outIdx := 0
- is := 0
- m := uint8(1)
- for n := 0; n < QK_K; n += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := d * float32(scales[is])
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l] >> shift) & 0x3)
- if hm[l]&m == 0 {
- qv -= 4
- }
- out[outIdx] = dl * float32(qv)
- outIdx++
- }
- dl = d * float32(scales[is])
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l+16] >> shift) & 0x3)
- if hm[l+16]&m == 0 {
- qv -= 4
- }
- out[outIdx] = dl * float32(qv)
- outIdx++
- }
- shift += 2
- m <<= 1
- }
- q = q[32:]
- }
- }
- func DotQ3_K(b *BlockQ3_K, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ3_K: mismatched slice length")
- }
- if sum, ok := dotQ3KSimd(b, x); ok {
- return sum
- }
- d := FP16ToFP32(b.D)
- var scales [16]int8
- for i := 0; i < 16; i++ {
- scales[i] = unpackQ3Scale(b.Scales[:], i)
- }
- q := b.QS[:]
- hm := b.HMask[:]
- outIdx := 0
- is := 0
- m := uint8(1)
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- shift := uint(0)
- for j := 0; j < 4; j++ {
- dl := d * float32(scales[is])
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l] >> shift) & 0x3)
- if hm[l]&m == 0 {
- qv -= 4
- }
- sum += x[outIdx] * (dl * float32(qv))
- outIdx++
- }
- dl = d * float32(scales[is])
- is++
- for l := 0; l < 16; l++ {
- qv := int8((q[l+16] >> shift) & 0x3)
- if hm[l+16]&m == 0 {
- qv -= 4
- }
- sum += x[outIdx] * (dl * float32(qv))
- outIdx++
- }
- shift += 2
- m <<= 1
- }
- q = q[32:]
- }
- return sum
- }
// BlockQ6_K is one Q6_K super-block: 256 weights at 6 bits each, stored in
// 210 bytes with no padding (128 + 64 + 16 + 2). The decoders in this file
// compute each weight as D * Scales[sub-block] * (code - 32), where code is
// the weight's 6-bit pattern assembled from QL and QH.
type BlockQ6_K struct {
	QL     [128]byte // low 4 bits of each code, two codes per byte (low/high nibble)
	QH     [64]byte  // high 2 bits of each code, four codes per byte
	Scales [16]int8  // signed 8-bit scale per 16-weight sub-block
	D      uint16    // float16 super-scale, decoded with FP16ToFP32
}
type (
	// Q6KDotParams holds, for one BlockQ6_K, its 16 signed sub-block scales
	// already multiplied by the block's super-scale D (filled by
	// GetQ6KDotParams).
	Q6KDotParams struct {
		S [16]float32
	}

	// q6kParamsKey identifies a []BlockQ6_K by the address of its first
	// element and its length.
	q6kParamsKey struct {
		p unsafe.Pointer
		n int
	}
)

// q6kDotParamsCache memoizes GetQ6KDotParams results: q6kParamsKey -> []Q6KDotParams.
var q6kDotParamsCache sync.Map
- func GetQ6KDotParams(blocks []BlockQ6_K) []Q6KDotParams {
- if len(blocks) == 0 {
- return nil
- }
- key := q6kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)}
- if v, ok := q6kDotParamsCache.Load(key); ok {
- return v.([]Q6KDotParams)
- }
- params := make([]Q6KDotParams, len(blocks))
- for bi := range blocks {
- b := &blocks[bi]
- d := FP16ToFP32(b.D)
- var p Q6KDotParams
- for i := 0; i < 16; i++ {
- p.S[i] = d * float32(b.Scales[i])
- }
- params[bi] = p
- }
- q6kDotParamsCache.Store(key, params)
- return params
- }
- // DequantizeQ6_K dequantizes a single Q6_K block into 256 floats
- // Logic adapted from llama.cpp's dequantize_row_q6_K
- func DequantizeQ6_K(b *BlockQ6_K, out []float32) {
- if dequantQ6KSimd(b, out) {
- return
- }
- d := FP16ToFP32(b.D)
- qlPtr := 0
- qhPtr := 0
- scPtr := 0
- outPtr := 0
- for n := 0; n < 256; n += 128 {
- for l := 0; l < 32; l++ {
- is := l / 16
- ql0 := b.QL[qlPtr+l]
- ql32 := b.QL[qlPtr+l+32]
- qh := b.QH[qhPtr+l]
- q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
- q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
- q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
- q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
- out[outPtr+l+0] = d * float32(b.Scales[scPtr+is+0]) * float32(q1)
- out[outPtr+l+32] = d * float32(b.Scales[scPtr+is+2]) * float32(q2)
- out[outPtr+l+64] = d * float32(b.Scales[scPtr+is+4]) * float32(q3)
- out[outPtr+l+96] = d * float32(b.Scales[scPtr+is+6]) * float32(q4)
- }
- outPtr += 128
- qlPtr += 64
- qhPtr += 32
- scPtr += 8
- }
- }
- func DotQ6_K(b *BlockQ6_K, x []float32) float32 {
- if len(x) != QK_K {
- panic("DotQ6_K: mismatched slice length")
- }
- if sum, ok := dotQ6KSimd(b, x); ok {
- return sum
- }
- d := FP16ToFP32(b.D)
- qlPtr := 0
- qhPtr := 0
- scPtr := 0
- outPtr := 0
- var sum float32
- for n := 0; n < 256; n += 128 {
- for l := 0; l < 32; l++ {
- is := l / 16
- ql0 := b.QL[qlPtr+l]
- ql32 := b.QL[qlPtr+l+32]
- qh := b.QH[qhPtr+l]
- q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
- q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
- q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
- q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
- s0 := d * float32(b.Scales[scPtr+is+0])
- s2 := d * float32(b.Scales[scPtr+is+2])
- s4 := d * float32(b.Scales[scPtr+is+4])
- s6 := d * float32(b.Scales[scPtr+is+6])
- base := outPtr + l
- sum += x[base+0] * s0 * float32(q1)
- sum += x[base+32] * s2 * float32(q2)
- sum += x[base+64] * s4 * float32(q3)
- sum += x[base+96] * s6 * float32(q4)
- }
- outPtr += 128
- qlPtr += 64
- qhPtr += 32
- scPtr += 8
- }
- return sum
- }
|