| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521 |
- //go:build amd64
- // +build amd64
- #include "textflag.h"
- // ============================================================================
- // Q8_K Dequantization - AVX512
- // BlockQ8_K layout:
- // offset 0: D (float32)
- // offset 4: QS[256] (int8)
- // ============================================================================
- // func dequantQ8KAVX512(b *BlockQ8_K, out *float32)
- TEXT ·dequantQ8KAVX512(SB), NOSPLIT, $0-16
- MOVQ b+0(FP), DI
- MOVQ out+8(FP), SI
- // Broadcast scale d to Z0
- VBROADCASTSS (DI), Z0
- // QS pointer = b + 4
- LEAQ 4(DI), R8
- MOVQ SI, R9
- // Process 256 elements, 16 at a time (unrolled by 4)
- MOVQ $0, CX
- loop_q8:
- CMPQ CX, $256
- JGE done_q8
- // Load 16 int8, sign-extend to 16 int32, convert to float, multiply
- VPMOVSXBD (R8), Z1
- VCVTDQ2PS Z1, Z1
- VMULPS Z0, Z1, Z1
- VMOVUPS Z1, (R9)
- VPMOVSXBD 16(R8), Z2
- VCVTDQ2PS Z2, Z2
- VMULPS Z0, Z2, Z2
- VMOVUPS Z2, 64(R9)
- VPMOVSXBD 32(R8), Z3
- VCVTDQ2PS Z3, Z3
- VMULPS Z0, Z3, Z3
- VMOVUPS Z3, 128(R9)
- VPMOVSXBD 48(R8), Z4
- VCVTDQ2PS Z4, Z4
- VMULPS Z0, Z4, Z4
- VMOVUPS Z4, 192(R9)
- ADDQ $64, R8
- ADDQ $256, R9
- ADDQ $64, CX
- JMP loop_q8
- done_q8:
- VZEROUPPER
- RET
- // ==========================================================================
- // Q8_K Fused Dot - AVX512
- // Computes: sum_i x[i] * (d * float32(qs[i]))
- // func dotQ8KAVX512(b *BlockQ8_K, x *float32) float32
- TEXT ·dotQ8KAVX512(SB), NOSPLIT, $0-24
- MOVQ b+0(FP), DI
- MOVQ x+8(FP), SI
- // Load scale d
- MOVSS (DI), X0
- // QS pointer = b + 4
- LEAQ 4(DI), R8
- // Accumulator Z1
- VXORPS Z1, Z1, Z1
- MOVQ $0, CX
- dot_q8_512_loop:
- CMPQ CX, $256
- JGE dot_q8_512_reduce
- // 16x int8 -> 16x int32 -> 16x float
- VPMOVSXBD (R8), Z2
- VCVTDQ2PS Z2, Z2
- // load 16 floats from x
- VMOVUPS (SI), Z3
- VFMADD231PS Z3, Z2, Z1
- ADDQ $16, R8
- ADDQ $64, SI
- ADDQ $16, CX
- JMP dot_q8_512_loop
- dot_q8_512_reduce:
- // horizontal reduce Z1 -> X1
- VEXTRACTF32X8 $1, Z1, Y2
- VADDPS Y2, Y1, Y1
- VEXTRACTF128 $1, Y1, X2
- VADDPS X2, X1, X1
- VPSHUFD $0x4E, X1, X2
- VADDPS X2, X1, X1
- VPSHUFD $0xB1, X1, X2
- VADDPS X2, X1, X1
- // multiply by d
- MULSS X0, X1
- MOVSS X1, ret+16(FP)
- VZEROUPPER
- RET
- TEXT ·dotQ4KInnerAVX512(SB), NOSPLIT, $0-40
- MOVQ qs+0(FP), DI
- MOVQ x+8(FP), SI
- LEAQ 128(SI), R10
- VBROADCASTSS d1+16(FP), Z0
- VBROADCASTSS m1+20(FP), Z1
- VBROADCASTSS d2+24(FP), Z2
- VBROADCASTSS m2+28(FP), Z3
- MOVL $0x0F0F0F0F, AX
- MOVD AX, X7
- VPBROADCASTD X7, Z7
- VXORPS Z12, Z12, Z12
- VXORPS Z13, Z13, Z13
- VXORPS Z14, Z14, Z14
- VXORPS Z15, Z15, Z15
- MOVQ $0, CX
- dot_q4k_512_loop:
- CMPQ CX, $32
- JGE dot_q4k_512_reduce
- VPMOVZXBD (DI), Z8
- VPANDD Z7, Z8, Z9
- VCVTDQ2PS Z9, Z9
- VMOVUPS (SI), Z10
- VADDPS Z10, Z13, Z13
- VFMADD231PS Z10, Z9, Z12
- VPSRLD $4, Z8, Z8
- VCVTDQ2PS Z8, Z8
- VMOVUPS (R10), Z11
- VADDPS Z11, Z15, Z15
- VFMADD231PS Z11, Z8, Z14
- ADDQ $16, DI
- ADDQ $64, SI
- ADDQ $64, R10
- ADDQ $16, CX
- JMP dot_q4k_512_loop
- dot_q4k_512_reduce:
- VMULPS Z0, Z12, Z12
- VMULPS Z1, Z13, Z13
- VSUBPS Z13, Z12, Z12
- VMULPS Z2, Z14, Z14
- VMULPS Z3, Z15, Z15
- VSUBPS Z15, Z14, Z14
- VADDPS Z14, Z12, Z12
- VEXTRACTF32X8 $1, Z12, Y1
- VADDPS Y1, Y12, Y12
- VEXTRACTF128 $1, Y12, X1
- VADDPS X1, X12, X12
- VPSHUFD $0x4E, X12, X1
- VADDPS X1, X12, X12
- VPSHUFD $0xB1, X12, X1
- VADDPS X1, X12, X12
- MOVSS X12, ret+32(FP)
- VZEROUPPER
- RET
- // func dotQ5KInnerAVX512(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
- TEXT ·dotQ5KInnerAVX512(SB), NOSPLIT, $0-64
- MOVQ qs+0(FP), DI
- MOVQ qh+8(FP), SI
- MOVQ x+16(FP), DX
- LEAQ 128(DX), R10
- VBROADCASTSS d1+24(FP), Z0
- VBROADCASTSS m1+28(FP), Z1
- VBROADCASTSS d2+32(FP), Z2
- VBROADCASTSS m2+36(FP), Z3
- MOVL $0x0F0F0F0F, AX
- MOVD AX, X7
- VPBROADCASTD X7, Z7
- MOVL $1, AX
- MOVD AX, X4
- VPBROADCASTD X4, Z4
- MOVQ u1+40(FP), CX
- MOVQ $1, AX
- SHLQ CL, AX
- MOVL AX, BX
- MOVD BX, X6
- VPBROADCASTD X6, Z6
- MOVQ u2+48(FP), CX
- MOVQ $1, AX
- SHLQ CL, AX
- MOVL AX, BX
- MOVD BX, X5
- VPBROADCASTD X5, Z5
- VXORPS Z15, Z15, Z15
- MOVQ u1+40(FP), AX
- CMPQ AX, $0
- JE dot_q5k_loop_s0
- CMPQ AX, $2
- JE dot_q5k_loop_s1
- CMPQ AX, $4
- JE dot_q5k_loop_s2
- JMP dot_q5k_loop_s3
- dot_q5k_loop_s0:
- MOVQ $0, CX
- dot_q5k_loop0:
- CMPQ CX, $32
- JGE dot_q5k_reduce
- VPMOVZXBD (DI), Z11
- VPANDD Z7, Z11, Z9
- VPSRLD $4, Z11, Z10
- VPMOVZXBD (SI), Z12
- VPANDD Z6, Z12, Z13
- VPSRLD $0, Z13, Z13
- VPANDD Z4, Z13, Z13
- VPSLLD $4, Z13, Z13
- VPANDD Z5, Z12, Z8
- VPSRLD $1, Z8, Z8
- VPANDD Z4, Z8, Z8
- VPSLLD $4, Z8, Z8
- VPADDD Z13, Z9, Z9
- VPADDD Z8, Z10, Z10
- VCVTDQ2PS Z9, Z9
- VMULPS Z0, Z9, Z9
- VSUBPS Z1, Z9, Z9
- VMOVUPS (DX), Z14
- VFMADD231PS Z14, Z9, Z15
- VCVTDQ2PS Z10, Z10
- VMULPS Z2, Z10, Z10
- VSUBPS Z3, Z10, Z10
- VMOVUPS (R10), Z14
- VFMADD231PS Z14, Z10, Z15
- ADDQ $16, DI
- ADDQ $16, SI
- ADDQ $64, DX
- ADDQ $64, R10
- ADDQ $16, CX
- JMP dot_q5k_loop0
- dot_q5k_loop_s1:
- MOVQ $0, CX
- dot_q5k_loop1:
- CMPQ CX, $32
- JGE dot_q5k_reduce
- VPMOVZXBD (DI), Z11
- VPANDD Z7, Z11, Z9
- VPSRLD $4, Z11, Z10
- VPMOVZXBD (SI), Z12
- VPANDD Z6, Z12, Z13
- VPSRLD $2, Z13, Z13
- VPANDD Z4, Z13, Z13
- VPSLLD $4, Z13, Z13
- VPANDD Z5, Z12, Z8
- VPSRLD $3, Z8, Z8
- VPANDD Z4, Z8, Z8
- VPSLLD $4, Z8, Z8
- VPADDD Z13, Z9, Z9
- VPADDD Z8, Z10, Z10
- VCVTDQ2PS Z9, Z9
- VMULPS Z0, Z9, Z9
- VSUBPS Z1, Z9, Z9
- VMOVUPS (DX), Z14
- VFMADD231PS Z14, Z9, Z15
- VCVTDQ2PS Z10, Z10
- VMULPS Z2, Z10, Z10
- VSUBPS Z3, Z10, Z10
- VMOVUPS (R10), Z14
- VFMADD231PS Z14, Z10, Z15
- ADDQ $16, DI
- ADDQ $16, SI
- ADDQ $64, DX
- ADDQ $64, R10
- ADDQ $16, CX
- JMP dot_q5k_loop1
- dot_q5k_loop_s2:
- MOVQ $0, CX
- dot_q5k_loop2:
- CMPQ CX, $32
- JGE dot_q5k_reduce
- VPMOVZXBD (DI), Z11
- VPANDD Z7, Z11, Z9
- VPSRLD $4, Z11, Z10
- VPMOVZXBD (SI), Z12
- VPANDD Z6, Z12, Z13
- VPSRLD $4, Z13, Z13
- VPANDD Z4, Z13, Z13
- VPSLLD $4, Z13, Z13
- VPANDD Z5, Z12, Z8
- VPSRLD $5, Z8, Z8
- VPANDD Z4, Z8, Z8
- VPSLLD $4, Z8, Z8
- VPADDD Z13, Z9, Z9
- VPADDD Z8, Z10, Z10
- VCVTDQ2PS Z9, Z9
- VMULPS Z0, Z9, Z9
- VSUBPS Z1, Z9, Z9
- VMOVUPS (DX), Z14
- VFMADD231PS Z14, Z9, Z15
- VCVTDQ2PS Z10, Z10
- VMULPS Z2, Z10, Z10
- VSUBPS Z3, Z10, Z10
- VMOVUPS (R10), Z14
- VFMADD231PS Z14, Z10, Z15
- ADDQ $16, DI
- ADDQ $16, SI
- ADDQ $64, DX
- ADDQ $64, R10
- ADDQ $16, CX
- JMP dot_q5k_loop2
- dot_q5k_loop_s3:
- MOVQ $0, CX
- dot_q5k_loop3:
- CMPQ CX, $32
- JGE dot_q5k_reduce
- VPMOVZXBD (DI), Z11
- VPANDD Z7, Z11, Z9
- VPSRLD $4, Z11, Z10
- VPMOVZXBD (SI), Z12
- VPANDD Z6, Z12, Z13
- VPSRLD $6, Z13, Z13
- VPANDD Z4, Z13, Z13
- VPSLLD $4, Z13, Z13
- VPANDD Z5, Z12, Z8
- VPSRLD $7, Z8, Z8
- VPANDD Z4, Z8, Z8
- VPSLLD $4, Z8, Z8
- VPADDD Z13, Z9, Z9
- VPADDD Z8, Z10, Z10
- VCVTDQ2PS Z9, Z9
- VMULPS Z0, Z9, Z9
- VSUBPS Z1, Z9, Z9
- VMOVUPS (DX), Z14
- VFMADD231PS Z14, Z9, Z15
- VCVTDQ2PS Z10, Z10
- VMULPS Z2, Z10, Z10
- VSUBPS Z3, Z10, Z10
- VMOVUPS (R10), Z14
- VFMADD231PS Z14, Z10, Z15
- ADDQ $16, DI
- ADDQ $16, SI
- ADDQ $64, DX
- ADDQ $64, R10
- ADDQ $16, CX
- JMP dot_q5k_loop3
- dot_q5k_reduce:
- VEXTRACTF32X8 $1, Z15, Y1
- VADDPS Y1, Y15, Y15
- VEXTRACTF128 $1, Y15, X1
- VADDPS X1, X15, X15
- VPSHUFD $0x4E, X15, X1
- VADDPS X1, X15, X15
- VPSHUFD $0xB1, X15, X1
- VADDPS X1, X15, X15
- MOVSS X15, ret+56(FP)
- VZEROUPPER
- RET
- TEXT ·dotQ6KInnerAVX512(SB), NOSPLIT, $0-40
- MOVQ ql+0(FP), DI
- MOVQ qh+8(FP), SI
- MOVQ scales+16(FP), DX
- MOVQ x+24(FP), R8
- MOVL $0x0F, AX
- MOVD AX, X11
- VPBROADCASTD X11, Z11
- MOVL $0x03, AX
- MOVD AX, X10
- VPBROADCASTD X10, Z10
- MOVL $0x42000000, AX
- MOVD AX, X9
- VBROADCASTSS X9, Z9
- VXORPS Z15, Z15, Z15
- MOVQ $0, CX
- dot_q6k_512_loop:
- CMPQ CX, $32
- JGE dot_q6k_512_reduce
- MOVQ CX, R11
- SHRQ $4, R11
- VPMOVZXBD (DI)(CX*1), Z0
- VPMOVZXBD 32(DI)(CX*1), Z1
- VPMOVZXBD (SI)(CX*1), Z2
- VPANDD Z11, Z0, Z3
- VPANDD Z10, Z2, Z4
- VPSLLD $4, Z4, Z4
- VPADDD Z4, Z3, Z3
- VCVTDQ2PS Z3, Z3
- VSUBPS Z9, Z3, Z3
- VBROADCASTSS (DX)(R11*4), Z4
- VMULPS Z4, Z3, Z3
- LEAQ (R8)(CX*4), R12
- VMOVUPS (R12), Z5
- VFMADD231PS Z5, Z3, Z15
- VPANDD Z11, Z1, Z3
- VPSRLD $2, Z2, Z4
- VPANDD Z10, Z4, Z4
- VPSLLD $4, Z4, Z4
- VPADDD Z4, Z3, Z3
- VCVTDQ2PS Z3, Z3
- VSUBPS Z9, Z3, Z3
- VBROADCASTSS 8(DX)(R11*4), Z4
- VMULPS Z4, Z3, Z3
- VMOVUPS 128(R12), Z5
- VFMADD231PS Z5, Z3, Z15
- VPSRLD $4, Z0, Z3
- VPANDD Z11, Z3, Z3
- VPSRLD $4, Z2, Z4
- VPANDD Z10, Z4, Z4
- VPSLLD $4, Z4, Z4
- VPADDD Z4, Z3, Z3
- VCVTDQ2PS Z3, Z3
- VSUBPS Z9, Z3, Z3
- VBROADCASTSS 16(DX)(R11*4), Z4
- VMULPS Z4, Z3, Z3
- VMOVUPS 256(R12), Z5
- VFMADD231PS Z5, Z3, Z15
- VPSRLD $4, Z1, Z3
- VPANDD Z11, Z3, Z3
- VPSRLD $6, Z2, Z4
- VPANDD Z10, Z4, Z4
- VPSLLD $4, Z4, Z4
- VPADDD Z4, Z3, Z3
- VCVTDQ2PS Z3, Z3
- VSUBPS Z9, Z3, Z3
- VBROADCASTSS 24(DX)(R11*4), Z4
- VMULPS Z4, Z3, Z3
- VMOVUPS 384(R12), Z5
- VFMADD231PS Z5, Z3, Z15
- ADDQ $16, CX
- JMP dot_q6k_512_loop
- dot_q6k_512_reduce:
- VEXTRACTF32X8 $1, Z15, Y1
- VADDPS Y1, Y15, Y15
- VEXTRACTF128 $1, Y15, X1
- VADDPS X1, X15, X15
- VPSHUFD $0x4E, X15, X1
- VADDPS X1, X15, X15
- VPSHUFD $0xB1, X15, X1
- VADDPS X1, X15, X15
- MOVSS X15, ret+32(FP)
- VZEROUPPER
- RET
|