package tensor

import (
	"sync"

	"golang.org/x/sys/cpu"
)

// CPU feature flags, evaluated once at package initialisation.
var (
	// hasAVX512 requires the F, DQ, BW and VL subsets together.
	hasAVX512 = cpu.X86.HasAVX512F && cpu.X86.HasAVX512DQ && cpu.X86.HasAVX512BW && cpu.X86.HasAVX512VL
	// hasAVX2 additionally requires FMA.
	hasAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasFMA
)

// State of the one-time Q6_K kernel self-test performed by q6kSimdReady.
var (
	q6kSimdOnce sync.Once
	q6kSimdOK   bool
)

// absDiffF32 returns |a - b|. If either operand is NaN the result is NaN.
func absDiffF32(a, b float32) float32 {
	if a > b {
		return a - b
	}
	return b - a
}

// dequantizeQ6KScalar expands one Q6_K block into 256 float32 values using
// plain Go. out must hold at least 256 elements; a shorter slice panics on
// the out-of-range index.
//
// The block is processed as two 128-value halves. Within a half, every index
// l in [0,32) produces four outputs (at l, l+32, l+64 and l+96): the low and
// high nibbles of QL[l] and QL[l+32] provide the low four bits, and the four
// 2-bit fields of QH[l] provide the top two bits of each 6-bit quant. Each
// quant is re-centred by subtracting 32 and multiplied by d and by one of the
// int8 sub-block scales.
func dequantizeQ6KScalar(b *BlockQ6_K, out []float32) {
	d := FP16ToFP32(b.D)
	qlPtr := 0
	qhPtr := 0
	scPtr := 0
	outPtr := 0
	for n := 0; n < 256; n += 128 {
		for l := 0; l < 32; l++ {
			// One scale covers 16 consecutive values, so l selects scale 0 or 1
			// within each group of 32 outputs.
			is := l / 16
			ql0 := b.QL[qlPtr+l]
			ql32 := b.QL[qlPtr+l+32]
			qh := b.QH[qhPtr+l]

			q1 := int8((ql0&0xF)|(((qh>>0)&3)<<4)) - 32
			q2 := int8((ql32&0xF)|(((qh>>2)&3)<<4)) - 32
			q3 := int8((ql0>>4)|(((qh>>4)&3)<<4)) - 32
			q4 := int8((ql32>>4)|(((qh>>6)&3)<<4)) - 32

			out[outPtr+l+0] = d * float32(b.Scales[scPtr+is+0]) * float32(q1)
			out[outPtr+l+32] = d * float32(b.Scales[scPtr+is+2]) * float32(q2)
			out[outPtr+l+64] = d * float32(b.Scales[scPtr+is+4]) * float32(q3)
			out[outPtr+l+96] = d * float32(b.Scales[scPtr+is+6]) * float32(q4)
		}
		outPtr += 128
		qlPtr += 64
		qhPtr += 32
		scPtr += 8
	}
}

// q6kSimdReady reports whether the AVX2 Q6_K kernel may be used.
//
// It returns false immediately when AVX2+FMA is unavailable. Otherwise, on the
// first call only, it dequantizes a synthetic block with both the scalar
// reference and the AVX2 kernel and records whether all 256 outputs agree to
// within 1e-4. Later calls return the recorded result.
func q6kSimdReady() bool {
	if !hasAVX2 {
		return false
	}
	q6kSimdOnce.Do(func() {
		const tol = 1e-4

		// Deterministic test block: unit scale factor, sub-block scales cycling
		// through -8..7, and varied bit patterns in the packed quant bytes.
		var blk BlockQ6_K
		blk.D = 0x3C00 // 1.0 as an IEEE 754 half-precision bit pattern
		for i := range blk.Scales {
			blk.Scales[i] = int8((i % 16) - 8)
		}
		for i := range blk.QL {
			blk.QL[i] = uint8(i)
		}
		for i := range blk.QH {
			blk.QH[i] = uint8(i * 3)
		}

		var want [256]float32
		dequantizeQ6KScalar(&blk, want[:])

		// The kernel handles one 128-value half per call.
		var got [256]float32
		d := FP16ToFP32(blk.D)
		dequantQ6KInnerAVX2(&blk.QL[0], &blk.QH[0], &blk.Scales[0], &got[0], d)
		dequantQ6KInnerAVX2(&blk.QL[64], &blk.QH[32], &blk.Scales[8], &got[128], d)

		for i := 0; i < 256; i++ {
			// Written as a negated "<=" on purpose: a NaN produced by the kernel
			// yields a NaN difference, and every ordered comparison with NaN is
			// false, so "diff > tol" would silently accept it.
			if diff := absDiffF32(got[i], want[i]); !(diff <= tol) {
				q6kSimdOK = false
				return
			}
		}
		q6kSimdOK = true
	})
	return q6kSimdOK
}

// ============================================================================
// Assembly Declarations
// ============================================================================
//
// The functions below have no Go body; their implementations live outside this
// file. They receive raw pointers, so callers are responsible for making sure
// the pointed-to buffers are large enough.

// Q8_K - Full AVX2/AVX512 vectorization
//
//go:noescape
func dequantQ8KAVX512(b *BlockQ8_K, out *float32)

//go:noescape
func dequantQ8KAVX2(b *BlockQ8_K, out *float32)

//go:noescape
func dotQ8KAVX512(b *BlockQ8_K, x *float32) float32

//go:noescape
func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32

// Q4_K - Inner loop vectorization (32 quants at a time)
//
//go:noescape
func dequantQ4KInnerAVX2(qs *byte, out *float32, d1, m1, d2, m2 float32)

// Q4_K - Fused dot product (32 bytes -> 64 quants) against 64 float32 inputs.
//
//go:noescape
func dotQ4KInnerAVX512(qs *byte, x *float32, d1, m1, d2, m2 float32) float32

//go:noescape
func dotQ4KInnerAVX2(qs *byte, x *float32, d1, m1, d2, m2 float32) float32

// Q5_K - Inner dot product kernels
//
//go:noescape
func dotQ5KInnerAVX512(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32

//go:noescape
func dotQ5KInnerAVX2(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32

// Q2_K - Inner loop vectorization (16 quants at a time)
//
//go:noescape
func dequantQ2KInnerAVX2(q *byte, out *float32, dl, ml float32, shift uint)

//go:noescape
func dotQ2KInnerAVX2Fused(q *byte, x *float32, dl, ml float32, shift uint) float32

// Q3_K - Inner loop vectorization
//
//go:noescape
func dequantQ3KInnerAVX2(q *byte, hm *byte, out *float32, dl float32, m uint8, shift uint)

//go:noescape
func dotQ3KInnerAVX2Fused(q *byte, hm *byte, x *float32, dl float32, m uint8, shift uint) float32

// Q6_K - Inner loop vectorization
//
//go:noescape
func dequantQ6KInnerAVX2(ql *byte, qh *byte, scales *int8, out *float32, d float32)

//go:noescape
func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32

//go:noescape
func dotQ6KInnerAVX512(ql *byte, qh *byte, scales *float32, x *float32) float32

// ============================================================================
// Q8_K Dequantization
// ============================================================================

// dequantQ8KSimd attempts a vector-friendly dequant for Q8_K.
// Returns true if the fast path was taken.
func dequantQ8KSimd(b *BlockQ8_K, out []float32) bool { if hasAVX512 { dequantQ8KAVX512(b, &out[0]) return true } if hasAVX2 { dequantQ8KAVX2(b, &out[0]) return true } return false } // ============================================================================ // Q4_K Dequantization // ============================================================================ // dequantQ4KSimd performs vectorized Q4_K dequantization using AVX2. func dequantQ4KSimd(b *BlockQ4_K, out []float32) bool { if !hasAVX2 { return false } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) // Decode 6-bit scales and mins var sc [8]uint8 var m [8]uint8 for j := 0; j < 4; j++ { sc[j] = b.Scales[j] & 63 m[j] = b.Scales[j+4] & 63 } for j := 4; j < 8; j++ { sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4) m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4) } outPtr := 0 qsPtr := 0 for i := 0; i < 8; i += 2 { d1 := d * float32(sc[i]) m1 := dmin * float32(m[i]) d2 := d * float32(sc[i+1]) m2 := dmin * float32(m[i+1]) // Use AVX2 kernel for inner loop dequantQ4KInnerAVX2(&b.QS[qsPtr], &out[outPtr], d1, m1, d2, m2) outPtr += 64 qsPtr += 32 } return true } func dotQ4KSimd(b *BlockQ4_K, x []float32) (float32, bool) { useAVX512 := hasAVX512 useAVX2 := hasAVX2 if !useAVX512 && !useAVX2 { return 0, false } if len(x) != QK_K { return 0, false } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) // Decode 6-bit scales and mins var sc [8]uint8 var m [8]uint8 for j := 0; j < 4; j++ { sc[j] = b.Scales[j] & 63 m[j] = b.Scales[j+4] & 63 } for j := 4; j < 8; j++ { sc[j] = (b.Scales[j+4] & 0xF) | ((b.Scales[j-4] >> 6) << 4) m[j] = (b.Scales[j+4] >> 4) | ((b.Scales[j-0] >> 6) << 4) } var sum float32 outPtr := 0 qsPtr := 0 for i := 0; i < 8; i += 2 { d1 := d * float32(sc[i]) m1 := dmin * float32(m[i]) d2 := d * float32(sc[i+1]) m2 := dmin * float32(m[i+1]) if useAVX512 { sum += dotQ4KInnerAVX512(&b.QS[qsPtr], &x[outPtr], d1, m1, d2, m2) } else { sum += dotQ4KInnerAVX2(&b.QS[qsPtr], &x[outPtr], d1, m1, 
d2, m2) } outPtr += 64 qsPtr += 32 } return sum, true } func dotQ8KSimd(b *BlockQ8_K, x []float32) (float32, bool) { useAVX512 := hasAVX512 useAVX2 := hasAVX2 if !useAVX512 && !useAVX2 { return 0, false } if len(x) != QK_K { return 0, false } if useAVX512 { return dotQ8KAVX512(b, &x[0]), true } return dotQ8KAVX2(b, &x[0]), true } func dotQ2KSimd(b *BlockQ2_K, x []float32) (float32, bool) { if !hasAVX2 { return 0, false } if len(x) != QK_K { return 0, false } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) is := 0 xIdx := 0 qOffset := 0 var sum float32 for n := 0; n < QK_K; n += 128 { for shift := uint(0); shift < 8; shift += 2 { sc := b.Scales[is] is++ dl := d * float32(sc&0xF) ml := dmin * float32(sc>>4) sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset], &x[xIdx], dl, ml, shift) xIdx += 16 sc = b.Scales[is] is++ dl = d * float32(sc&0xF) ml = dmin * float32(sc>>4) sum += dotQ2KInnerAVX2Fused(&b.QS[qOffset+16], &x[xIdx], dl, ml, shift) xIdx += 16 } qOffset += 32 } return sum, true } func dotQ3KSimd(b *BlockQ3_K, x []float32) (float32, bool) { if !hasAVX2 { return 0, false } if len(x) != QK_K { return 0, false } d := FP16ToFP32(b.D) var scales [16]float32 for i := 0; i < 16; i++ { scales[i] = float32(unpackQ3Scale(b.Scales[:], i)) } q := b.QS[:] hm := b.HMask[:] is := 0 xIdx := 0 m := uint8(1) var sum float32 for n := 0; n < QK_K; n += 128 { for j := 0; j < 4; j++ { dl := d * scales[is] sum += dotQ3KInnerAVX2Fused(&q[0], &hm[0], &x[xIdx], dl, m, uint(j*2)) xIdx += 16 is++ dl = d * scales[is] sum += dotQ3KInnerAVX2Fused(&q[16], &hm[16], &x[xIdx], dl, m, uint(j*2)) xIdx += 16 is++ m <<= 1 } q = q[32:] } return sum, true } // ============================================================================ // Q2_K Dequantization // ============================================================================ // dequantQ2KSimd performs vectorized Q2_K dequantization. 
func dequantQ2KSimd(b *BlockQ2_K, out []float32) bool { if !hasAVX2 { return false } d := FP16ToFP32(b.D) dmin := FP16ToFP32(b.DMin) is := 0 outIdx := 0 qOffset := 0 for n := 0; n < QK_K; n += 128 { for shift := uint(0); shift < 8; shift += 2 { sc := b.Scales[is] is++ dl := d * float32(sc&0xF) ml := dmin * float32(sc>>4) // Process 16 elements with AVX2 dequantQ2KInnerAVX2(&b.QS[qOffset], &out[outIdx], dl, ml, shift) outIdx += 16 sc = b.Scales[is] is++ dl = d * float32(sc&0xF) ml = dmin * float32(sc>>4) // Process next 16 elements dequantQ2KInnerAVX2(&b.QS[qOffset+16], &out[outIdx], dl, ml, shift) outIdx += 16 } qOffset += 32 } return true } // ============================================================================ // Q3_K Dequantization // ============================================================================ // dequantQ3KSimd returns false because benchmarks showed the scalar path is faster. // Q3_K has complex 3-bit + high-bit packing that doesn't benefit from SIMD. // Scalar: 443ns, Unrolled Go: 502ns func dequantQ3KSimd(b *BlockQ3_K, out []float32) bool { if !hasAVX2 { return false } d := FP16ToFP32(b.D) var scales [16]float32 for i := 0; i < 16; i++ { scales[i] = float32(unpackQ3Scale(b.Scales[:], i)) } q := b.QS[:] hm := b.HMask[:] outIdx := 0 is := 0 m := uint8(1) // Same loop structure as scalar, but vectorized inner loop for n := 0; n < QK_K; n += 128 { for j := 0; j < 4; j++ { // First 16 dl1 := d * scales[is] dequantQ3KInnerAVX2(&q[0], &hm[0], &out[outIdx], dl1, m, uint(j*2)) is++ outIdx += 16 // Second 16 (offset by 16 in q/hm) dl2 := d * scales[is] dequantQ3KInnerAVX2(&q[16], &hm[16], &out[outIdx], dl2, m, uint(j*2)) is++ outIdx += 16 // In scalar, mask m shifts left m <<= 1 } q = q[32:] // reset mask for next 128 block? // Wait, scalar: m starts at 1. Correct. 
} return true } // ============================================================================ // Q6_K Dequantization // ============================================================================ // dequantQ6KSimd returns false because benchmarks showed the scalar path is equivalent. // Q6_K has complex 6-bit packing that doesn't benefit from our Go-based optimization. // Scalar: 515ns, Unrolled Go: 521ns func dequantQ6KSimd(b *BlockQ6_K, out []float32) bool { // disabled: verification failure in block calculation if !q6kSimdReady() { return false } if len(out) != QK_K { return false } d := FP16ToFP32(b.D) dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &out[0], d) dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &out[128], d) return true } func dotQ6KSimd(b *BlockQ6_K, x []float32) (float32, bool) { if !q6kSimdReady() { return 0, false } if len(x) != QK_K { return 0, false } var tmp [128]float32 d := FP16ToFP32(b.D) var sum float32 dequantQ6KInnerAVX2(&b.QL[0], &b.QH[0], &b.Scales[0], &tmp[0], d) for i := 0; i < 128; i++ { sum += x[i] * tmp[i] } dequantQ6KInnerAVX2(&b.QL[64], &b.QH[32], &b.Scales[8], &tmp[0], d) for i := 0; i < 128; i++ { sum += x[128+i] * tmp[i] } return sum, true }