| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486 |
- package tensor
- import (
- "sync"
- "golang.org/x/sys/cpu"
- )
- var (
- hasAVX512 = cpu.X86.HasAVX512F && cpu.X86.HasAVX512DQ && cpu.X86.HasAVX512BW && cpu.X86.HasAVX512VL
- hasAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasFMA
- )
- var q6kSimdOnce sync.Once
- var q6kSimdOK bool
// absDiffF32 returns the magnitude of the difference between a and b.
//
// The larger operand is chosen as the minuend, so the subtraction is never
// negative. If the comparison a > b is false (which includes the case where
// either operand is NaN) the result is b - a.
func absDiffF32(a, b float32) float32 {
	hi, lo := b, a
	if a > b {
		hi, lo = a, b
	}
	return hi - lo
}
- func dequantizeQ6KScalar(b *BlockQ6_K, out []float32) {
- d := FP16ToFP32(b.D)
- qlPtr := 0
- qhPtr := 0
- scPtr := 0
- outPtr := 0
- for n := 0; n < 256; n += 128 {
- for l := 0; l < 32; l++ {
- is := l / 16
- ql0 := b.QL[qlPtr+l]
- ql32 := b.QL[qlPtr+l+32]
- qh := b.QH[qhPtr+l]
- q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
- q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
- q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
- q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32
- out[outPtr+l+0] = d * float32(b.Scales[scPtr+is+0]) * float32(q1)
- out[outPtr+l+32] = d * float32(b.Scales[scPtr+is+2]) * float32(q2)
- out[outPtr+l+64] = d * float32(b.Scales[scPtr+is+4]) * float32(q3)
- out[outPtr+l+96] = d * float32(b.Scales[scPtr+is+6]) * float32(q4)
- }
- outPtr += 128
- qlPtr += 64
- qhPtr += 32
- scPtr += 8
- }
- }
- func q6kSimdReady() bool {
- if !hasAVX2 {
- return false
- }
- q6kSimdOnce.Do(func() {
- var b BlockQ6_K
- b.D = 0x3C00
- for i := range b.Scales {
- b.Scales[i] = int8((i % 16) - 8)
- }
- for i := range b.QL {
- b.QL[i] = uint8(i)
- }
- for i := range b.QH {
- b.QH[i] = uint8(i * 3)
- }
- var outScalar [256]float32
- dequantizeQ6KScalar(&b, outScalar[:])
- var outSimd [256]float32
- d := FP16ToFP32(b.D)
- dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &outSimd[0], d)
- dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &outSimd[128], d)
- for i := 0; i < 256; i++ {
- if absDiffF32(outSimd[i], outScalar[i]) > 1e-4 {
- q6kSimdOK = false
- return
- }
- }
- q6kSimdOK = true
- })
- return q6kSimdOK
- }
- // ============================================================================
- // Assembly Declarations
- // ============================================================================
// Q8_K - Full AVX2/AVX512 vectorization
//
// The four Q8_K kernels below are declared without a Go body; per the section
// header they are implemented in assembly.
//
// dequantQ8KAVX512 is the AVX-512 Q8_K dequantization kernel. dequantQ8KSimd
// calls it with the block and a pointer to the first element of the output
// slice. The kernel cannot see the slice length, so the caller is responsible
// for the buffer size.
//
//go:noescape
func dequantQ8KAVX512(b *BlockQ8_K, out *float32)

// dequantQ8KAVX2 is the AVX2 counterpart of dequantQ8KAVX512; dequantQ8KSimd
// calls it with the same arguments when only AVX2 is available.
//
//go:noescape
func dequantQ8KAVX2(b *BlockQ8_K, out *float32)

// dotQ8KAVX512 is the AVX-512 Q8_K dot-product kernel. dotQ8KSimd calls it
// with a pointer to x[0] after checking that x holds exactly QK_K values, and
// returns its result unchanged.
//
//go:noescape
func dotQ8KAVX512(b *BlockQ8_K, x *float32) float32

// dotQ8KAVX2 is the AVX2 counterpart of dotQ8KAVX512, called by dotQ8KSimd
// under the same length check.
//
//go:noescape
func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32
// Q4_K - Inner loop vectorization
//
// dequantQ4KInnerAVX2 is the assembly kernel used by dequantQ4KSimd. The
// caller advances qs by 32 bytes and out by 64 elements between calls, and
// passes two scale/min pairs: d1/m1 and d2/m2 are the block scale and block
// min multiplied by the decoded 6-bit scale and min of two consecutive
// sub-blocks.
// NOTE(review): an earlier comment here said "32 quants at a time"; the
// caller's stride suggests 64 outputs per call - confirm against the assembly.
//
//go:noescape
func dequantQ4KInnerAVX2(qs *byte, out *float32, d1, m1, d2, m2 float32)

// Q4_K - Fused dot product (32 bytes -> 64 quants) against 64 float32 inputs.
//
// dotQ4KInnerAVX512 is the AVX-512 variant. dotQ4KSimd calls it four times per
// block, stepping qs by 32 bytes and x by 64 elements, and sums the results.
//
//go:noescape
func dotQ4KInnerAVX512(qs *byte, x *float32, d1, m1, d2, m2 float32) float32

// dotQ4KInnerAVX2 is the AVX2 variant of dotQ4KInnerAVX512; dotQ4KSimd uses it
// with identical arguments when AVX-512 is unavailable.
//
//go:noescape
func dotQ4KInnerAVX2(qs *byte, x *float32, d1, m1, d2, m2 float32) float32

// dotQ5KInnerAVX512 is an assembly dot-product kernel for Q5_K. No caller is
// visible in this part of the file.
// NOTE(review): presumably qh carries the fifth-bit plane and u1/u2 select the
// bits within it, with d1/m1/d2/m2 as for Q4_K - confirm against the caller
// and the assembly.
//
//go:noescape
func dotQ5KInnerAVX512(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32

// dotQ5KInnerAVX2 has the same signature as dotQ5KInnerAVX512; presumably it
// is the AVX2 variant of the same kernel (no caller is visible here).
//
//go:noescape
func dotQ5KInnerAVX2(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
// Q2_K - Inner loop vectorization (16 quants at a time)
//
// dequantQ2KInnerAVX2 is the assembly kernel used by dequantQ2KSimd. The
// caller advances out by 16 elements per call, passes shift values 0, 2, 4
// and 6, and supplies dl and ml as the block scale and block min multiplied by
// the low and high nibble of the sub-block scale byte respectively.
//
//go:noescape
func dequantQ2KInnerAVX2(q *byte, out *float32, dl, ml float32, shift uint)

// dotQ2KInnerAVX2Fused is the assembly kernel used by dotQ2KSimd. It receives
// the same q, dl, ml and shift arguments as dequantQ2KInnerAVX2, with x
// pointing at the 16 inputs for this step, and returns a partial sum that the
// caller accumulates.
//
//go:noescape
func dotQ2KInnerAVX2Fused(q *byte, x *float32, dl, ml float32, shift uint) float32
// Q3_K - Inner loop vectorization
//
// dequantQ3KInnerAVX2 is the assembly kernel used by dequantQ3KSimd. The
// caller advances out by 16 elements per call. q and hm point into the block's
// QS and HMask arrays, dl is the block scale multiplied by the unpacked
// sub-block scale, m is a single-bit mask that the caller shifts left after
// every pair of calls, and shift takes the values 0, 2, 4 and 6.
//
//go:noescape
func dequantQ3KInnerAVX2(q *byte, hm *byte, out *float32, dl float32, m uint8, shift uint)

// dotQ3KInnerAVX2Fused is the assembly kernel used by dotQ3KSimd. It takes the
// same q, hm, dl, m and shift arguments as dequantQ3KInnerAVX2, with x
// pointing at the 16 inputs for this step, and returns a partial sum that the
// caller accumulates.
//
//go:noescape
func dotQ3KInnerAVX2Fused(q *byte, hm *byte, x *float32, dl float32, m uint8, shift uint) float32
// Q6_K - Inner loop vectorization
//
// dequantQ6KInnerAVX2 is the assembly kernel used by q6kSimdReady,
// dequantQ6KSimd and dotQ6KSimd. Callers invoke it twice per block, the second
// time with ql advanced by 64 bytes, qh by 32 bytes, scales by 8 entries and
// out by 128 elements, so each call is expected to produce 128 outputs. d is
// the block scale already converted to float32.
//
//go:noescape
func dequantQ6KInnerAVX2(ql *byte, qh *byte, scales *int8, out *float32, d float32)

// dotQ6KInnerAVX2 is an assembly dot-product kernel for Q6_K. It is not called
// anywhere in the visible part of this file; dotQ6KSimd dequantizes and
// accumulates in Go instead. Note that scales is *float32 here, unlike the
// *int8 taken by dequantQ6KInnerAVX2.
// NOTE(review): presumably the caller must pre-multiply the scales by the
// block scale - confirm against the assembly.
//
//go:noescape
func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32

// dotQ6KInnerAVX512 has the same signature as dotQ6KInnerAVX2; presumably it
// is the AVX-512 variant. It is likewise not called in the visible part of
// this file.
//
//go:noescape
func dotQ6KInnerAVX512(ql *byte, qh *byte, scales *float32, x *float32) float32
- // ============================================================================
- // Q8_K Dequantization
- // ============================================================================
- // dequantQ8KSimd attempts a vector-friendly dequant for Q8_K.
- // Returns true if the fast path was taken.
- func dequantQ8KSimd(b *BlockQ8_K, out []float32) bool {
- if hasAVX512 {
- dequantQ8KAVX512(b, &out[0])
- return true
- }
- if hasAVX2 {
- dequantQ8KAVX2(b, &out[0])
- return true
- }
- return false
- }
- // ============================================================================
- // Q4_K Dequantization
- // ============================================================================
- // dequantQ4KSimd performs vectorized Q4_K dequantization using AVX2.
- func dequantQ4KSimd(b *BlockQ4_K, out []float32) bool {
- if !hasAVX2 {
- return false
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- // Decode 6-bit scales and mins
- var sc [8]uint8
- var m [8]uint8
- for j := 0; j < 4; j++ {
- sc[j] = b.Scales[j] & 63
- m[j] = b.Scales[j+4] & 63
- }
- for j := 4; j < 8; j++ {
- sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
- m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
- }
- outPtr := 0
- qsPtr := 0
- for i := 0; i < 8; i += 2 {
- d1 := d * float32(sc[i])
- m1 := dmin * float32(m[i])
- d2 := d * float32(sc[i+1])
- m2 := dmin * float32(m[i+1])
- // Use AVX2 kernel for inner loop
- dequantQ4KInnerAVX2(&b.QS[qsPtr], &out[outPtr], d1, m1, d2, m2)
- outPtr += 64
- qsPtr += 32
- }
- return true
- }
- func dotQ4KSimd(b *BlockQ4_K, x []float32) (float32, bool) {
- useAVX512 := hasAVX512
- useAVX2 := hasAVX2
- if !useAVX512 && !useAVX2 {
- return 0, false
- }
- if len(x) != QK_K {
- return 0, false
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- // Decode 6-bit scales and mins
- var sc [8]uint8
- var m [8]uint8
- for j := 0; j < 4; j++ {
- sc[j] = b.Scales[j] & 63
- m[j] = b.Scales[j+4] & 63
- }
- for j := 4; j < 8; j++ {
- sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4)
- m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4)
- }
- var sum float32
- outPtr := 0
- qsPtr := 0
- for i := 0; i < 8; i += 2 {
- d1 := d * float32(sc[i])
- m1 := dmin * float32(m[i])
- d2 := d * float32(sc[i+1])
- m2 := dmin * float32(m[i+1])
- if useAVX512 {
- sum += dotQ4KInnerAVX512(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
- } else {
- sum += dotQ4KInnerAVX2(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2)
- }
- outPtr += 64
- qsPtr += 32
- }
- return sum, true
- }
- func dotQ8KSimd(b *BlockQ8_K, x []float32) (float32, bool) {
- useAVX512 := hasAVX512
- useAVX2 := hasAVX2
- if !useAVX512 && !useAVX2 {
- return 0, false
- }
- if len(x) != QK_K {
- return 0, false
- }
- if useAVX512 {
- return dotQ8KAVX512(b, &x[0]), true
- }
- return dotQ8KAVX2(b, &x[0]), true
- }
- func dotQ2KSimd(b *BlockQ2_K, x []float32) (float32, bool) {
- if !hasAVX2 {
- return 0, false
- }
- if len(x) != QK_K {
- return 0, false
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- is := 0
- xIdx := 0
- qOffset := 0
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- for shift := uint(0); shift < 8; shift += 2 {
- sc := b.Scales[is]
- is++
- dl := d * float32(sc&0xF)
- ml := dmin * float32(sc>>4)
- sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset], &x[xIdx], dl, ml, shift)
- xIdx += 16
- sc = b.Scales[is]
- is++
- dl = d * float32(sc&0xF)
- ml = dmin * float32(sc>>4)
- sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], &x[xIdx], dl, ml, shift)
- xIdx += 16
- }
- qOffset += 32
- }
- return sum, true
- }
- func dotQ3KSimd(b *BlockQ3_K, x []float32) (float32, bool) {
- if !hasAVX2 {
- return 0, false
- }
- if len(x) != QK_K {
- return 0, false
- }
- d := FP16ToFP32(b.D)
- var scales [16]float32
- for i := 0; i < 16; i++ {
- scales[i] = float32(unpackQ3Scale(b.Scales[:], i))
- }
- q := b.QS[:]
- hm := b.HMask[:]
- is := 0
- xIdx := 0
- m := uint8(1)
- var sum float32
- for n := 0; n < QK_K; n += 128 {
- for j := 0; j < 4; j++ {
- dl := d * scales[is]
- sum += dotQ3KInnerAVX2Fused(&q[0], &hm[0], &x[xIdx], dl, m, uint(j*2))
- xIdx += 16
- is++
- dl = d * scales[is]
- sum += dotQ3KInnerAVX2Fused(&q[16], &hm[16], &x[xIdx], dl, m, uint(j*2))
- xIdx += 16
- is++
- m <<= 1
- }
- q = q[32:]
- }
- return sum, true
- }
- // ============================================================================
- // Q2_K Dequantization
- // ============================================================================
- // dequantQ2KSimd performs vectorized Q2_K dequantization.
- func dequantQ2KSimd(b *BlockQ2_K, out []float32) bool {
- if !hasAVX2 {
- return false
- }
- d := FP16ToFP32(b.D)
- dmin := FP16ToFP32(b.DMin)
- is := 0
- outIdx := 0
- qOffset := 0
- for n := 0; n < QK_K; n += 128 {
- for shift := uint(0); shift < 8; shift += 2 {
- sc := b.Scales[is]
- is++
- dl := d * float32(sc&0xF)
- ml := dmin * float32(sc>>4)
- // Process 16 elements with AVX2
- dequantQ2KInnerAVX2(&b.QS[qOffset], &out[outIdx], dl, ml, shift)
- outIdx += 16
- sc = b.Scales[is]
- is++
- dl = d * float32(sc&0xF)
- ml = dmin * float32(sc>>4)
- // Process next 16 elements
- dequantQ2KInnerAVX2(&b.QS[qOffset+16], &out[outIdx], dl, ml, shift)
- outIdx += 16
- }
- qOffset += 32
- }
- return true
- }
- // ============================================================================
- // Q3_K Dequantization
- // ============================================================================
- // dequantQ3KSimd returns false because benchmarks showed the scalar path is faster.
- // Q3_K has complex 3-bit + high-bit packing that doesn't benefit from SIMD.
- // Scalar: 443ns, Unrolled Go: 502ns
- func dequantQ3KSimd(b *BlockQ3_K, out []float32) bool {
- if !hasAVX2 {
- return false
- }
- d := FP16ToFP32(b.D)
- var scales [16]float32
- for i := 0; i < 16; i++ {
- scales[i] = float32(unpackQ3Scale(b.Scales[:], i))
- }
- q := b.QS[:]
- hm := b.HMask[:]
- outIdx := 0
- is := 0
- m := uint8(1)
- // Same loop structure as scalar, but vectorized inner loop
- for n := 0; n < QK_K; n += 128 {
- for j := 0; j < 4; j++ {
- // First 16
- dl1 := d * scales[is]
- dequantQ3KInnerAVX2(&q[0], &hm[0], &out[outIdx], dl1, m, uint(j*2))
- is++
- outIdx += 16
-
- // Second 16 (offset by 16 in q/hm)
- dl2 := d * scales[is]
- dequantQ3KInnerAVX2(&q[16], &hm[16], &out[outIdx], dl2, m, uint(j*2))
- is++
- outIdx += 16
- // In scalar, mask m shifts left
- m <<= 1
- }
- q = q[32:]
- // reset mask for next 128 block?
- // Wait, scalar: m starts at 1. Correct.
- }
- return true
- }
- // ============================================================================
- // Q6_K Dequantization
- // ============================================================================
- // dequantQ6KSimd returns false because benchmarks showed the scalar path is equivalent.
- // Q6_K has complex 6-bit packing that doesn't benefit from our Go-based optimization.
- // Scalar: 515ns, Unrolled Go: 521ns
- func dequantQ6KSimd(b *BlockQ6_K, out []float32) bool {
- // disabled: verification failure in block calculation
- if !q6kSimdReady() {
- return false
- }
- if len(out) != QK_K {
- return false
- }
- d := FP16ToFP32(b.D)
-
- dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &out[0], d)
- dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &out[128], d)
- return true
- }
- func dotQ6KSimd(b *BlockQ6_K, x []float32) (float32, bool) {
- if !q6kSimdReady() {
- return 0, false
- }
- if len(x) != QK_K {
- return 0, false
- }
- var tmp [128]float32
- d := FP16ToFP32(b.D)
- var sum float32
- dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &tmp[0], d)
- for i := 0; i < 128; i++ {
- sum += x[i] * tmp[i]
- }
- dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &tmp[0], d)
- for i := 0; i < 128; i++ {
- sum += x[128+i] * tmp[i]
- }
- return sum, true
- }
|