package tensor import ( "math" "sync" "unsafe" ) // QK_K is the super-block size for K-quants const QK_K = 256 // unpackQ3Scale extracts a signed 6-bit scale from packed bytes (llama.cpp layout) func unpackQ3Scale(packed []byte, idx int) int8 { var sc uint8 if idx < 8 { sc = packed[idx] & 0xF } else { sc = packed[idx-8] >> 4 } sc |= ((packed[8+(idx%4)] >> (2 * (idx / 4))) & 0x3) << 4 return int8(sc) - 32 } // BlockQ4_K represents a block of 256 weights quantized to 4 bits with super-block scales // Layout (144 bytes): // - D (2 bytes): float16 super-scale // - DMin (2 bytes): float16 super-min-scale // - Scales (12 bytes): 8 6-bit scales and 8 6-bit mins packed // - QS (128 bytes): 256 4-bit quants type BlockQ4_K struct { D uint16 DMin uint16 Scales [12]uint8 QS [128]uint8 } type BlockQ5_K struct { D uint16 DMin uint16 Scales [12]uint8 QH [32]uint8 QS [128]uint8 } func getScaleMinK4(j int, q *[12]uint8) (d uint8, m uint8) { if j < 4 { d = q[j] & 63 m = q[j+4] & 63 return d, m } d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4) m = (q[j+4] >> 4) | ((q[j] >> 6) << 4) return d, m } // FP16ToFP32 converts a float16 (as uint16) to float32. // Implements full IEEE 754 half-precision conversion. func FP16ToFP32(n uint16) float32 { sign := uint32(n&0x8000) << 16 exp := uint32(n&0x7C00) >> 10 mant := uint32(n & 0x03FF) // Normalized case (most common for model weights) if exp > 0 && exp < 0x1F { return math.Float32frombits(sign | ((exp + 112) << 23) | (mant << 13)) } // Zero or Denormalized if exp == 0 { if mant == 0 { return math.Float32frombits(sign) } // Denormalized number // Renormalize: multiply by 2^(-14) // 1024.0 is 2^10 m := float32(mant) / 1024.0 val := m * float32(math.Pow(2, -14)) if sign != 0 { val = -val } return val } // Infinity or NaN (exp == 0x1F) if mant == 0 { return math.Float32frombits(sign | 0x7F800000) // Infinity } return math.Float32frombits(sign | 0x7FC00000 | (mant << 13)) // NaN } // DequantizeQ4_K dequantizes a single Q4_K block into 256 floats func DequantizeQ4_K(b *BlockQ4_K, out []float32) { if dequantQ4KSimd(b, out) { return } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) var sc [8]uint8 var m [8]uint8 for j := 0; j < 4; j++ { sc[j] = b.Scales[j] & 63 m[j] = b.Scales[j+4] & 63 } for j := 4; j < 8; j++ { sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4) m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4) } outPtr := 0 qsPtr := 0 for i := 0; i < 8; i += 2 { d1 := d * float32(sc[i]) m1 := dmin * float32(m[i]) d2 := d * float32(sc[i+1]) m2 := dmin * float32(m[i+1]) for l := 0; l < 32; l++ { val := b.QS[qsPtr+l] v1 := val & 0xF v2 := val >> 4 out[outPtr] = float32(v1)*d1 - m1 out[outPtr+32] = float32(v2)*d2 - m2 outPtr++ } outPtr += 32 qsPtr += 32 } } func DotQ4_K(b *BlockQ4_K, x []float32) float32 { if len(x) != QK_K { panic("DotQ4_K: mismatched slice length") } if sum, ok := dotQ4KSimd(b, x); ok { return sum } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) var sc [8]uint8 var m [8]uint8 for j := 0; j < 4; j++ { sc[j] = b.Scales[j] & 63 m[j] = b.Scales[j+4] & 63 } for j := 4; j < 8; j++ { sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4) m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4) } var sum float32 outPtr := 0 qsPtr := 0 for i := 0; i < 8; i += 2 { d1 := d * float32(sc[i]) m1 := dmin * float32(m[i]) d2 := d * float32(sc[i+1]) m2 := dmin * float32(m[i+1]) for l := 0; l < 32; l++ { val := b.QS[qsPtr+l] v1 := val & 0xF v2 := val >> 4 sum += x[outPtr] * (float32(v1)*d1 - m1) sum += x[outPtr+32] * (float32(v2)*d2 - m2) outPtr++ } outPtr += 32 qsPtr += 32 } return sum } func DotQ2_K_Params(b *BlockQ2_K, p *Q2KDotParams, x []float32) float32 { if len(x) != QK_K { panic("DotQ2_K_Params: mismatched slice length") } if hasAVX2 { is := 0 xIdx := 0 qOffset := 0 var sum float32 for n := 0; n < QK_K; n += 128 { _ = n shift := uint(0) for j := 0; j < 4; j++ { dl := p.DL[is] ml := p.ML[is] is++ sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset], &x[xIdx], dl, ml, shift) xIdx += 16 dl = p.DL[is] ml = p.ML[is] is++ sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], &x[xIdx], dl, ml, shift) xIdx += 16 shift += 2 } qOffset += 32 } return sum } q := b.QS[:] is := 0 outIdx := 0 var sum float32 for n := 0; n < QK_K; n += 128 { shift := uint(0) for j := 0; j < 4; j++ { dl := p.DL[is] ml := p.ML[is] is++ for l := 0; l < 16; l++ { val := float32(int8((q[l] >> shift) & 3)) sum += x[outIdx] * (dl*val - ml) outIdx++ } dl = p.DL[is] ml = p.ML[is] is++ for l := 0; l < 16; l++ { val := float32(int8((q[l+16] >> shift) & 3)) sum += x[outIdx] * (dl*val - ml) outIdx++ } shift += 2 } q = q[32:] } return sum } func DotQ2KTile8(sums *[8]float32, w []BlockQ2_K, wp []Q2KDotParams, base int, stride int, x *float32, n int) { if n <= 0 { return } xp := unsafe.Pointer(x) if hasAVX2 { for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] is := 0 xIdx := 0 qOffset := 0 for nn := 0; nn < QK_K; nn += 128 { _ = nn shift := uint(0) for j := 0; j < 4; j++ { dl := p.DL[is] ml := p.ML[is] is++ xSeg := (*float32)(unsafe.Add(xp, uintptr(xIdx)*4)) sums[t] += dotQ2KInnerAVX2Fused(&b.QS[qOffset], xSeg, dl, ml, shift) xIdx += 16 dl = p.DL[is] ml = p.ML[is] is++ xSeg = (*float32)(unsafe.Add(xp, uintptr(xIdx)*4)) sums[t] += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], xSeg, dl, ml, shift) xIdx += 16 shift += 2 } qOffset += 32 } } return } xSlice := unsafe.Slice(x, QK_K) for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] sums[t] += DotQ2_K_Params(b, p, xSlice) } } func DotQ8KTile8(sums *[8]float32, w []BlockQ8_K, base int, stride int, x *float32, n int) { if n <= 0 { return } if hasAVX512 { for t := 0; t < n; t++ { idx := base + t*stride sums[t] += dotQ8KAVX512(&w[idx], x) } return } if hasAVX2 { for t := 0; t < n; t++ { idx := base + t*stride sums[t] += dotQ8KAVX2(&w[idx], x) } return } xSlice := unsafe.Slice(x, QK_K) for t := 0; t < n; t++ { idx := base + t*stride sums[t] += DotQ8_K(&w[idx], xSlice) } } func DotQ3_K_Params(b *BlockQ3_K, p *Q3KDotParams, x []float32) float32 { if len(x) != QK_K { panic("DotQ3_K_Params: mismatched slice length") } if hasAVX2 { q := b.QS[:] hm := b.HMask[:] is := 0 xIdx := 0 m := uint8(1) var sum float32 for n := 0; n < QK_K; n += 128 { for j := 0; j < 4; j++ { dl := p.S[is] sum += dotQ3KInnerAVX2Fused(&q[0], &hm[0], &x[xIdx], dl, m, uint(j*2)) xIdx += 16 is++ dl = p.S[is] sum += dotQ3KInnerAVX2Fused(&q[16], &hm[16], &x[xIdx], dl, m, uint(j*2)) xIdx += 16 is++ m <<= 1 } q = q[32:] } return sum } q := b.QS[:] hm := b.HMask[:] outIdx := 0 is := 0 m := uint8(1) var sum float32 for n := 0; n < QK_K; n += 128 { shift := uint(0) for j := 0; j < 4; j++ { dl := p.S[is] is++ for l := 0; l < 16; l++ { qv := int8((q[l] >> shift) & 0x3) if hm[l]&m == 0 { qv -= 4 } sum += x[outIdx] * (dl * float32(qv)) outIdx++ } dl = p.S[is] is++ for l := 0; l < 16; l++ { qv := int8((q[l+16] >> shift) & 0x3) if hm[l+16]&m == 0 { qv -= 4 } sum += x[outIdx] * (dl * float32(qv)) outIdx++ } shift += 2 m <<= 1 } q = q[32:] } return sum } func DotQ3KTile8(sums *[8]float32, w []BlockQ3_K, wp []Q3KDotParams, base int, stride int, x *float32, n int) { if n <= 0 { return } xp := unsafe.Pointer(x) if hasAVX2 { for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] q := b.QS[:] hm := b.HMask[:] is := 0 xIdx := 0 m := uint8(1) for nn := 0; nn < QK_K; nn += 128 { for j := 0; j < 4; j++ { dl := p.S[is] xSeg := (*float32)(unsafe.Add(xp, uintptr(xIdx)*4)) sums[t] += dotQ3KInnerAVX2Fused(&q[0], &hm[0], xSeg, dl, m, uint(j*2)) xIdx += 16 is++ dl = p.S[is] xSeg = (*float32)(unsafe.Add(xp, uintptr(xIdx)*4)) sums[t] += dotQ3KInnerAVX2Fused(&q[16], &hm[16], xSeg, dl, m, uint(j*2)) xIdx += 16 is++ m <<= 1 } q = q[32:] } } return } for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] xIdx := 0 q := b.QS[:] hm := b.HMask[:] is := 0 m := uint8(1) for nn := 0; nn < QK_K; nn += 128 { shift := uint(0) for j := 0; j < 4; j++ { dl := p.S[is] is++ for l := 0; l < 16; l++ { qv := int8((q[l] >> shift) & 0x3) if hm[l]&m == 0 { qv -= 4 } x0 := *(*float32)(unsafe.Add(xp, uintptr(xIdx)*4)) sums[t] += x0 * (dl * float32(qv)) xIdx++ } dl = p.S[is] is++ for l := 0; l < 16; l++ { qv := int8((q[l+16] >> shift) & 0x3) if hm[l+16]&m == 0 { qv -= 4 } x0 := *(*float32)(unsafe.Add(xp, uintptr(xIdx)*4)) sums[t] += x0 * (dl * float32(qv)) xIdx++ } shift += 2 m <<= 1 } q = q[32:] } } } func DotQ6_K_Params(b *BlockQ6_K, p *Q6KDotParams, x []float32) float32 { if len(x) != QK_K { panic("DotQ6_K_Params: mismatched slice length") } if hasAVX512 { sum := dotQ6KInnerAVX512(&b.QL[0], &b.QH[0], &p.S[0], &x[0]) sum += dotQ6KInnerAVX512(&b.QL[64], &b.QH[32], &p.S[8], &x[128]) return sum } if hasAVX2 { sum := dotQ6KInnerAVX2(&b.QL[0], &b.QH[0], &p.S[0], &x[0]) sum += dotQ6KInnerAVX2(&b.QL[64], &b.QH[32], &p.S[8], &x[128]) return sum } qlPtr := 0 qhPtr := 0 outPtr := 0 var sum float32 for half := 0; half < 2; half++ { scBase := half * 8 for l := 0; l < 32; l++ { is := l / 16 ql0 := b.QL[qlPtr+l] ql32 := b.QL[qlPtr+l+32] qh := b.QH[qhPtr+l] q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32 q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32 q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32 q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32 s0 := p.S[scBase+is+0] s2 := p.S[scBase+is+2] s4 := p.S[scBase+is+4] s6 := p.S[scBase+is+6] base := outPtr + l sum += x[base+0] * s0 * float32(q1) sum += x[base+32] * s2 * float32(q2) sum += x[base+64] * s4 * float32(q3) sum += x[base+96] * s6 * float32(q4) } outPtr += 128 qlPtr += 64 qhPtr += 32 } return sum } func DotQ6KTile8(sums *[8]float32, w []BlockQ6_K, wp []Q6KDotParams, base int, stride int, x *float32, n int) { if n <= 0 { return } xp := unsafe.Pointer(x) if hasAVX512 { for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] sums[t] += dotQ6KInnerAVX512(&b.QL[0], &b.QH[0], &p.S[0], x) x128 := (*float32)(unsafe.Add(xp, 128*4)) sums[t] += dotQ6KInnerAVX512(&b.QL[64], &b.QH[32], &p.S[8], x128) } return } if hasAVX2 { for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] sums[t] += dotQ6KInnerAVX2(&b.QL[0], &b.QH[0], &p.S[0], x) x128 := (*float32)(unsafe.Add(xp, 128*4)) sums[t] += dotQ6KInnerAVX2(&b.QL[64], &b.QH[32], &p.S[8], x128) } return } for half := 0; half < 2; half++ { xBase := unsafe.Add(xp, uintptr(half*128)*4) qlPtr := half * 64 qhPtr := half * 32 scBase := half * 8 for l := 0; l < 32; l++ { x0 := *(*float32)(unsafe.Add(xBase, uintptr(l)*4)) x1 := *(*float32)(unsafe.Add(xBase, uintptr(32+l)*4)) x2 := *(*float32)(unsafe.Add(xBase, uintptr(64+l)*4)) x3 := *(*float32)(unsafe.Add(xBase, uintptr(96+l)*4)) is := l / 16 for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] ql0 := b.QL[qlPtr+l] ql32 := b.QL[qlPtr+l+32] qh := b.QH[qhPtr+l] q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32 q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32 q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32 q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32 s0 := p.S[scBase+is+0] s2 := p.S[scBase+is+2] s4 := p.S[scBase+is+4] s6 := p.S[scBase+is+6] sums[t] += x0*s0*float32(q1) + x1*s2*float32(q2) + x2*s4*float32(q3) + x3*s6*float32(q4) } } } } type Q5KDotParams struct { D1 [4]float32 M1 [4]float32 D2 [4]float32 M2 [4]float32 } type q5kParamsKey struct { p unsafe.Pointer n int } var q5kDotParamsCache sync.Map func GetQ5KDotParams(blocks []BlockQ5_K) []Q5KDotParams { if len(blocks) == 0 { return nil } key := q5kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)} if v, ok := q5kDotParamsCache.Load(key); ok { return v.([]Q5KDotParams) } params := make([]Q5KDotParams, len(blocks)) for bi := range blocks { b := &blocks[bi] d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) var p Q5KDotParams seg := 0 is := 0 for j := 0; j < QK_K; j += 64 { sc, m := getScaleMinK4(is+0, &b.Scales) p.D1[seg] = d * float32(sc) p.M1[seg] = dmin * float32(m) sc, m = getScaleMinK4(is+1, &b.Scales) p.D2[seg] = d * float32(sc) p.M2[seg] = dmin * float32(m) seg++ is += 2 } params[bi] = p } q5kDotParamsCache.Store(key, params) return params } func DotQ5_K_Params(b *BlockQ5_K, p *Q5KDotParams, x []float32) float32 { if len(x) != QK_K { panic("DotQ5_K_Params: mismatched slice length") } var sum float32 qsPtr := 0 qh := b.QH[:] for seg := 0; seg < 4; seg++ { d1, m1 := p.D1[seg], p.M1[seg] d2, m2 := p.D2[seg], p.M2[seg] outPtr := seg * 64 shift1 := uint(2 * seg) shift2 := uint(2*seg + 1) if hasAVX512 { sum += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], &x[outPtr], d1, m1, d2, m2, shift1, shift2) qsPtr += 32 continue } if hasAVX2 { sum += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], &x[outPtr], d1, m1, d2, m2, shift1, shift2) qsPtr += 32 continue } u1 := uint8(1 << shift1) u2 := uint8(1 << shift2) for l := 0; l < 32; l++ { v := int(b.QS[qsPtr+l] & 0xF) if (qh[l] & u1) != 0 { v += 16 } sum += x[outPtr+l] * (d1*float32(v) - m1) } for l := 0; l < 32; l++ { v := int(b.QS[qsPtr+l] >> 4) if (qh[l] & u2) != 0 { v += 16 } sum += x[outPtr+32+l] * (d2*float32(v) - m2) } qsPtr += 32 } return sum } func DotQ5_K_ParamsPtr(b *BlockQ5_K, p *Q5KDotParams, x *float32) float32 { var sum float32 xp := unsafe.Pointer(x) qsPtr := 0 qh := b.QH[:] for seg := 0; seg < 4; seg++ { d1, m1 := p.D1[seg], p.M1[seg] d2, m2 := p.D2[seg], p.M2[seg] outPtr := seg * 64 shift1 := uint(2 * seg) shift2 := uint(2*seg + 1) if hasAVX512 { xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4)) sum += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2) qsPtr += 32 continue } if hasAVX2 { xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4)) sum += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2) qsPtr += 32 continue } u1 := uint8(1 << shift1) u2 := uint8(1 << shift2) for l := 0; l < 32; l++ { v := int(b.QS[qsPtr+l] & 0xF) if (qh[l] & u1) != 0 { v += 16 } x0 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+l)*4)) sum += x0 * (d1*float32(v) - m1) } for l := 0; l < 32; l++ { v := int(b.QS[qsPtr+l] >> 4) if (qh[l] & u2) != 0 { v += 16 } x1 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+32+l)*4)) sum += x1 * (d2*float32(v) - m2) } qsPtr += 32 } return sum } func DotQ5KTile8(sums *[8]float32, w []BlockQ5_K, wp []Q5KDotParams, base int, stride int, x *float32, n int) { if n <= 0 { return } xp := unsafe.Pointer(x) for seg := 0; seg < 4; seg++ { xSeg := (*float32)(unsafe.Add(xp, uintptr(seg*64)*4)) qsPtr := seg * 32 shift1 := uint(2 * seg) shift2 := uint(2*seg + 1) u1 := uint8(1 << shift1) u2 := uint8(1 << shift2) xsp := unsafe.Pointer(xSeg) for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] d1, m1 := p.D1[seg], p.M1[seg] d2, m2 := p.D2[seg], p.M2[seg] if hasAVX512 { sums[t] += dotQ5KInnerAVX512(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2) continue } if hasAVX2 { sums[t] += dotQ5KInnerAVX2(&b.QS[qsPtr], &b.QH[0], xSeg, d1, m1, d2, m2, shift1, shift2) continue } qh := b.QH[:] for l := 0; l < 32; l++ { v0 := int(b.QS[qsPtr+l] & 0xF) if (qh[l] & u1) != 0 { v0 += 16 } v1 := int(b.QS[qsPtr+l] >> 4) if (qh[l] & u2) != 0 { v1 += 16 } x0 := *(*float32)(unsafe.Add(xsp, uintptr(l)*4)) x1 := *(*float32)(unsafe.Add(xsp, uintptr(32+l)*4)) sums[t] += x0*(d1*float32(v0)-m1) + x1*(d2*float32(v1)-m2) } } } } func DequantizeQ5_K(b *BlockQ5_K, out []float32) { d := FP16ToFP32(b.D) min := FP16ToFP32(b.DMin) outIdx := 0 ql := b.QS[:] qh := b.QH[:] is := 0 u1 := uint8(1) u2 := uint8(2) for j := 0; j < QK_K; j += 64 { sc, m := getScaleMinK4(is+0, &b.Scales) d1 := d * float32(sc) m1 := min * float32(m) sc, m = getScaleMinK4(is+1, &b.Scales) d2 := d * float32(sc) m2 := min * float32(m) for l := 0; l < 32; l++ { v := int(ql[l] & 0xF) if (qh[l] & u1) != 0 { v += 16 } out[outIdx] = d1*float32(v) - m1 outIdx++ } for l := 0; l < 32; l++ { v := int(ql[l] >> 4) if (qh[l] & u2) != 0 { v += 16 } out[outIdx] = d2*float32(v) - m2 outIdx++ } ql = ql[32:] is += 2 u1 <<= 2 u2 <<= 2 } } func DotQ5_K(b *BlockQ5_K, x []float32) float32 { if len(x) != QK_K { panic("DotQ5_K: mismatched slice length") } d := FP16ToFP32(b.D) min := FP16ToFP32(b.DMin) ql := b.QS[:] qh := b.QH[:] is := 0 u1 := uint8(1) u2 := uint8(2) var sum float32 outIdx := 0 for j := 0; j < QK_K; j += 64 { sc, m := getScaleMinK4(is+0, &b.Scales) d1 := d * float32(sc) m1 := min * float32(m) sc, m = getScaleMinK4(is+1, &b.Scales) d2 := d * float32(sc) m2 := min * float32(m) for l := 0; l < 32; l++ { v := int(ql[l] & 0xF) if (qh[l] & u1) != 0 { v += 16 } sum += x[outIdx] * (d1*float32(v) - m1) outIdx++ } for l := 0; l < 32; l++ { v := int(ql[l] >> 4) if (qh[l] & u2) != 0 { v += 16 } sum += x[outIdx] * (d2*float32(v) - m2) outIdx++ } ql = ql[32:] is += 2 u1 <<= 2 u2 <<= 2 } return sum } type Q4KDotParams struct { D1 [4]float32 M1 [4]float32 D2 [4]float32 M2 [4]float32 } type q4kParamsKey struct { p unsafe.Pointer n int } var q4kDotParamsCache sync.Map func GetQ4KDotParams(blocks []BlockQ4_K) []Q4KDotParams { if len(blocks) == 0 { return nil } key := q4kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)} if v, ok := q4kDotParamsCache.Load(key); ok { return v.([]Q4KDotParams) } params := make([]Q4KDotParams, len(blocks)) for bi := range blocks { b := &blocks[bi] d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) var sc [8]uint8 var m [8]uint8 for j := 0; j < 4; j++ { sc[j] = b.Scales[j] & 63 m[j] = b.Scales[j+4] & 63 } for j := 4; j < 8; j++ { sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4) m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4) } var p Q4KDotParams seg := 0 for i := 0; i < 8; i += 2 { p.D1[seg] = d * float32(sc[i]) p.M1[seg] = dmin * float32(m[i]) p.D2[seg] = d * float32(sc[i+1]) p.M2[seg] = dmin * float32(m[i+1]) seg++ } params[bi] = p } q4kDotParamsCache.Store(key, params) return params } func DotQ4_K_Params(b *BlockQ4_K, p *Q4KDotParams, x []float32) float32 { if len(x) != QK_K { panic("DotQ4_K_Params: mismatched slice length") } var sum float32 outPtr := 0 qsPtr := 0 for seg := 0; seg < 4; seg++ { d1, m1 := p.D1[seg], p.M1[seg] d2, m2 := p.D2[seg], p.M2[seg] if hasAVX512 { sum += dotQ4KInnerAVX512(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2) } else if hasAVX2 { sum += dotQ4KInnerAVX2(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2) } else { for l := 0; l < 32; l++ { val := b.QS[qsPtr+l] v1 := val & 0xF v2 := val >> 4 sum += x[outPtr] * (float32(v1)*d1 - m1) sum += x[outPtr+32] * (float32(v2)*d2 - m2) outPtr++ } outPtr += 32 qsPtr += 32 continue } outPtr += 64 qsPtr += 32 } return sum } func DotQ4_K_ParamsPtr(b *BlockQ4_K, p *Q4KDotParams, x *float32) float32 { var sum float32 xp := unsafe.Pointer(x) outPtr := 0 qsPtr := 0 for seg := 0; seg < 4; seg++ { d1, m1 := p.D1[seg], p.M1[seg] d2, m2 := p.D2[seg], p.M2[seg] xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4)) if hasAVX512 { sum += dotQ4KInnerAVX512(&b.QS[qsPtr], xSeg, d1, m1, d2, m2) } else if hasAVX2 { sum += dotQ4KInnerAVX2(&b.QS[qsPtr], xSeg, d1, m1, d2, m2) } else { for l := 0; l < 32; l++ { val := b.QS[qsPtr+l] v1 := val & 0xF v2 := val >> 4 x0 := *(*float32)(unsafe.Add(xp, uintptr(outPtr)*4)) x1 := *(*float32)(unsafe.Add(xp, uintptr(outPtr+32)*4)) sum += x0 * (float32(v1)*d1 - m1) sum += x1 * (float32(v2)*d2 - m2) outPtr++ } outPtr += 32 qsPtr += 32 continue } outPtr += 64 qsPtr += 32 } return sum } func DotQ4KTile8(sums *[8]float32, w []BlockQ4_K, wp []Q4KDotParams, base int, stride int, x *float32, n int) { if n <= 0 { return } xp := unsafe.Pointer(x) outPtr := 0 qsPtr := 0 for seg := 0; seg < 4; seg++ { xSeg := (*float32)(unsafe.Add(xp, uintptr(outPtr)*4)) for t := 0; t < n; t++ { idx := base + t*stride b := &w[idx] p := &wp[idx] d1, m1 := p.D1[seg], p.M1[seg] d2, m2 := p.D2[seg], p.M2[seg] if hasAVX512 { sums[t] += dotQ4KInnerAVX512(&b.QS[qsPtr], xSeg, d1, m1, d2, m2) } else if hasAVX2 { sums[t] += dotQ4KInnerAVX2(&b.QS[qsPtr], xSeg, d1, m1, d2, m2) } else { xsp := unsafe.Pointer(xSeg) for l := 0; l < 32; l++ { val := b.QS[qsPtr+l] v1 := val & 0xF v2 := val >> 4 x0 := *(*float32)(unsafe.Add(xsp, uintptr(l)*4)) x1 := *(*float32)(unsafe.Add(xsp, uintptr(32+l)*4)) sums[t] += x0*(float32(v1)*d1-m1) + x1*(float32(v2)*d2-m2) } } } outPtr += 64 qsPtr += 32 } } // BlockQ8_K represents a block of 256 weights quantized to 8 bits // Layout (292 bytes): // - D (4 bytes): float32 scale // - QS (256 bytes): 256 int8 quants // - BSums (32 bytes): 16 int16 block sums (for dot product optimization, not used in dequant) type BlockQ8_K struct { D float32 QS [256]int8 BSums [16]int16 } // DequantizeQ8_K dequantizes a single Q8_K block into 256 floats func DequantizeQ8_K(b *BlockQ8_K, out []float32) { if dequantQ8KSimd(b, out) { return } d := b.D for i := 0; i < 256; i++ { out[i] = d * float32(b.QS[i]) } } func DotQ8_K(b *BlockQ8_K, x []float32) float32 { if len(x) != QK_K { panic("DotQ8_K: mismatched slice length") } if sum, ok := dotQ8KSimd(b, x); ok { return sum } d := b.D var sum float32 for i := 0; i < 256; i++ { sum += x[i] * float32(b.QS[i]) } return d * sum } // BlockQ2_K represents a block of 256 weights quantized to 2 bits // Layout (84 bytes): // - Scales (16 bytes): 16 4-bit scales and 16 4-bit mins packed // - QS (64 bytes): 256 2-bit quants packed // - D (2 bytes): float16 super-scale // - DMin (2 bytes): float16 super-min-scale type BlockQ2_K struct { Scales [16]uint8 QS [64]uint8 D uint16 DMin uint16 } type Q2KDotParams struct { DL [16]float32 ML [16]float32 } type q2kParamsKey struct { p unsafe.Pointer n int } var q2kDotParamsCache sync.Map func GetQ2KDotParams(blocks []BlockQ2_K) []Q2KDotParams { if len(blocks) == 0 { return nil } key := q2kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)} if v, ok := q2kDotParamsCache.Load(key); ok { return v.([]Q2KDotParams) } params := make([]Q2KDotParams, len(blocks)) for bi := range blocks { b := &blocks[bi] d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) var p Q2KDotParams pi := 0 is := 0 for n := 0; n < QK_K; n += 128 { _ = n for shift := uint(0); shift < 8; shift += 2 { sc := b.Scales[is] is++ p.DL[pi] = d * float32(sc&0xF) p.ML[pi] = dmin * float32(sc>>4) pi++ sc = b.Scales[is] is++ p.DL[pi] = d * float32(sc&0xF) p.ML[pi] = dmin * float32(sc>>4) pi++ } } params[bi] = p } q2kDotParamsCache.Store(key, params) return params } // DequantizeQ2_K dequantizes a single Q2_K block into 256 floats // Mirrors llama.cpp dequantize_row_q2_K func DequantizeQ2_K(b *BlockQ2_K, out []float32) { if dequantQ2KSimd(b, out) { return } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) q := b.QS[:] is := 0 outIdx := 0 for n := 0; n < QK_K; n += 128 { shift := uint(0) for j := 0; j < 4; j++ { sc := b.Scales[is] is++ dl := d * float32(sc&0xF) ml := dmin * float32(sc>>4) for l := 0; l < 16; l++ { val := int8((q[l] >> shift) & 3) out[outIdx] = dl*float32(val) - ml outIdx++ } sc = b.Scales[is] is++ dl = d * float32(sc&0xF) ml = dmin * float32(sc>>4) for l := 0; l < 16; l++ { val := int8((q[l+16] >> shift) & 3) out[outIdx] = dl*float32(val) - ml outIdx++ } shift += 2 } q = q[32:] } } func DotQ2_K(b *BlockQ2_K, x []float32) float32 { if len(x) != QK_K { panic("DotQ2_K: mismatched slice length") } if sum, ok := dotQ2KSimd(b, x); ok { return sum } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) q := b.QS[:] is := 0 outIdx := 0 var sum float32 for n := 0; n < QK_K; n += 128 { shift := uint(0) for j := 0; j < 4; j++ { sc := b.Scales[is] is++ dl := d * float32(sc&0xF) ml := dmin * float32(sc>>4) for l := 0; l < 16; l++ { val := float32(int8((q[l] >> shift) & 3)) sum += x[outIdx] * (dl*val - ml) outIdx++ } sc = b.Scales[is] is++ dl = d * float32(sc&0xF) ml = dmin * float32(sc>>4) for l := 0; l < 16; l++ { val := float32(int8((q[l+16] >> shift) & 3)) sum += x[outIdx] * (dl*val - ml) outIdx++ } shift += 2 } q = q[32:] } return sum } // BlockQ3_K represents a block of 256 weights quantized to 3 bits // Layout (110 bytes): // - HMask (32 bytes): high bit of 3-bit quants // - QS (64 bytes): low 2 bits of 3-bit quants // - Scales (12 bytes): 6-bit scales packed // - D (2 bytes): float16 super-scale type BlockQ3_K struct { HMask [32]uint8 QS [64]uint8 Scales [12]uint8 D uint16 } type Q3KDotParams struct { S [16]float32 } type q3kParamsKey struct { p unsafe.Pointer n int } var q3kDotParamsCache sync.Map func GetQ3KDotParams(blocks []BlockQ3_K) []Q3KDotParams { if len(blocks) == 0 { return nil } key := q3kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)} if v, ok := q3kDotParamsCache.Load(key); ok { return v.([]Q3KDotParams) } params := make([]Q3KDotParams, len(blocks)) for bi := range blocks { b := &blocks[bi] d := FP16ToFP32(b.D) var p Q3KDotParams for i := 0; i < 16; i++ { p.S[i] = d * float32(unpackQ3Scale(b.Scales[:], i)) } params[bi] = p } q3kDotParamsCache.Store(key, params) return params } // DequantizeQ3_K dequantizes a single Q3_K block into 256 floats func DequantizeQ3_K(b *BlockQ3_K, out []float32) { if dequantQ3KSimd(b, out) { return } d := FP16ToFP32(b.D) var scales [16]int8 for i := 0; i < 16; i++ { scales[i] = unpackQ3Scale(b.Scales[:], i) } q := b.QS[:] hm := b.HMask[:] outIdx := 0 is := 0 m := uint8(1) for n := 0; n < QK_K; n += 128 { shift := uint(0) for j := 0; j < 4; j++ { dl := d * float32(scales[is]) is++ for l := 0; l < 16; l++ { qv := int8((q[l] >> shift) & 0x3) if hm[l]&m == 0 { qv -= 4 } out[outIdx] = dl * float32(qv) outIdx++ } dl = d * float32(scales[is]) is++ for l := 0; l < 16; l++ { qv := int8((q[l+16] >> shift) & 0x3) if hm[l+16]&m == 0 { qv -= 4 } out[outIdx] = dl * float32(qv) outIdx++ } shift += 2 m <<= 1 } q = q[32:] } } func DotQ3_K(b *BlockQ3_K, x []float32) float32 { if len(x) != QK_K { panic("DotQ3_K: mismatched slice length") } if sum, ok := dotQ3KSimd(b, x); ok { return sum } d := FP16ToFP32(b.D) var scales [16]int8 for i := 0; i < 16; i++ { scales[i] = unpackQ3Scale(b.Scales[:], i) } q := b.QS[:] hm := b.HMask[:] outIdx := 0 is := 0 m := uint8(1) var sum float32 for n := 0; n < QK_K; n += 128 { shift := uint(0) for j := 0; j < 4; j++ { dl := d * float32(scales[is]) is++ for l := 0; l < 16; l++ { qv := int8((q[l] >> shift) & 0x3) if hm[l]&m == 0 { qv -= 4 } sum += x[outIdx] * (dl * float32(qv)) outIdx++ } dl = d * float32(scales[is]) is++ for l := 0; l < 16; l++ { qv := int8((q[l+16] >> shift) & 0x3) if hm[l+16]&m == 0 { qv -= 4 } sum += x[outIdx] * (dl * float32(qv)) outIdx++ } shift += 2 m <<= 1 } q = q[32:] } return sum } // BlockQ6_K represents a block of 256 weights quantized to 6 bits // Layout (210 bytes): // - QL (128 bytes): lower 4 bits of 6-bit quants // - QH (64 bytes): upper 2 bits of 6-bit quants // - Scales (16 bytes): 8-bit signed scales // - D (2 bytes): float16 super-scale type BlockQ6_K struct { QL [128]uint8 QH [64]uint8 Scales [16]int8 D uint16 } type Q6KDotParams struct { S [16]float32 } type q6kParamsKey struct { p unsafe.Pointer n int } var q6kDotParamsCache sync.Map func GetQ6KDotParams(blocks []BlockQ6_K) []Q6KDotParams { if len(blocks) == 0 { return nil } key := q6kParamsKey{p: unsafe.Pointer(&blocks[0]), n: len(blocks)} if v, ok := q6kDotParamsCache.Load(key); ok { return v.([]Q6KDotParams) } params := make([]Q6KDotParams, len(blocks)) for bi := range blocks { b := &blocks[bi] d := FP16ToFP32(b.D) var p Q6KDotParams for i := 0; i < 16; i++ { p.S[i] = d * float32(b.Scales[i]) } params[bi] = p } q6kDotParamsCache.Store(key, params) return params } // DequantizeQ6_K dequantizes a single Q6_K block into 256 floats // Logic adapted from llama.cpp's dequantize_row_q6_K func DequantizeQ6_K(b *BlockQ6_K, out []float32) { if dequantQ6KSimd(b, out) { return } d := FP16ToFP32(b.D) qlPtr := 0 qhPtr := 0 scPtr := 0 outPtr := 0 for n := 0; n < 256; n += 128 { for l := 0; l < 32; l++ { is := l / 16 ql0 := b.QL[qlPtr+l] ql32 := b.QL[qlPtr+l+32] qh := b.QH[qhPtr+l] q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32 q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32 q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32 q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32 out[outPtr+l+0] = d * float32(b.Scales[scPtr+is+0]) * float32(q1) out[outPtr+l+32] = d * float32(b.Scales[scPtr+is+2]) * float32(q2) out[outPtr+l+64] = d * float32(b.Scales[scPtr+is+4]) * float32(q3) out[outPtr+l+96] = d * float32(b.Scales[scPtr+is+6]) * float32(q4) } outPtr += 128 qlPtr += 64 qhPtr += 32 scPtr += 8 } } func DotQ6_K(b *BlockQ6_K, x []float32) float32 { if len(x) != QK_K { panic("DotQ6_K: mismatched slice length") } if sum, ok := dotQ6KSimd(b, x); ok { return sum } d := FP16ToFP32(b.D) qlPtr := 0 qhPtr := 0 scPtr := 0 outPtr := 0 var sum float32 for n := 0; n < 256; n += 128 { for l := 0; l < 32; l++ { is := l / 16 ql0 := b.QL[qlPtr+l] ql32 := b.QL[qlPtr+l+32] qh := b.QH[qhPtr+l] q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32 q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32 q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32 q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32 s0 := d * float32(b.Scales[scPtr+is+0]) s2 := d * float32(b.Scales[scPtr+is+2]) s4 := d * float32(b.Scales[scPtr+is+4]) s6 := d * float32(b.Scales[scPtr+is+6]) base := outPtr + l sum += x[base+0] * s0 * float32(q1) sum += x[base+32] * s2 * float32(q2) sum += x[base+64] * s4 * float32(q3) sum += x[base+96] * s6 * float32(q4) } outPtr += 128 qlPtr += 64 qhPtr += 32 scPtr += 8 } return sum }