//go:build amd64 // +build amd64 #include "textflag.h" // ============================================================================ // Q8_K Dequantization - AVX512 // BlockQ8_K layout: // offset 0: D (float32) // offset 4: QS[256] (int8) // ============================================================================ // func dequantQ8KAVX512(b *BlockQ8_K, out *float32) TEXT ·dequantQ8KAVX512(SB), NOSPLIT, $0-16 MOVQ b+0(FP), DI MOVQ out+8(FP), SI // Broadcast scale d to Z0 VBROADCASTSS (DI), Z0 // QS pointer = b + 4 LEAQ 4(DI), R8 MOVQ SI, R9 // Process 256 elements, 16 at a time (unrolled by 4) MOVQ $0, CX loop_q8: CMPQ CX, $256 JGE done_q8 // Load 16 int8, sign-extend to 16 int32, convert to float, multiply VPMOVSXBD (R8), Z1 VCVTDQ2PS Z1, Z1 VMULPS Z0, Z1, Z1 VMOVUPS Z1, (R9) VPMOVSXBD 16(R8), Z2 VCVTDQ2PS Z2, Z2 VMULPS Z0, Z2, Z2 VMOVUPS Z2, 64(R9) VPMOVSXBD 32(R8), Z3 VCVTDQ2PS Z3, Z3 VMULPS Z0, Z3, Z3 VMOVUPS Z3, 128(R9) VPMOVSXBD 48(R8), Z4 VCVTDQ2PS Z4, Z4 VMULPS Z0, Z4, Z4 VMOVUPS Z4, 192(R9) ADDQ $64, R8 ADDQ $256, R9 ADDQ $64, CX JMP loop_q8 done_q8: VZEROUPPER RET // ========================================================================== // Q8_K Fused Dot - AVX512 // Computes: sum_i x[i] * (d * float32(qs[i])) // func dotQ8KAVX512(b *BlockQ8_K, x *float32) float32 TEXT ·dotQ8KAVX512(SB), NOSPLIT, $0-24 MOVQ b+0(FP), DI MOVQ x+8(FP), SI // Load scale d MOVSS (DI), X0 // QS pointer = b + 4 LEAQ 4(DI), R8 // Accumulator Z1 VXORPS Z1, Z1, Z1 MOVQ $0, CX dot_q8_512_loop: CMPQ CX, $256 JGE dot_q8_512_reduce // 16x int8 -> 16x int32 -> 16x float VPMOVSXBD (R8), Z2 VCVTDQ2PS Z2, Z2 // load 16 floats from x VMOVUPS (SI), Z3 VFMADD231PS Z3, Z2, Z1 ADDQ $16, R8 ADDQ $64, SI ADDQ $16, CX JMP dot_q8_512_loop dot_q8_512_reduce: // horizontal reduce Z1 -> X1 VEXTRACTF32X8 $1, Z1, Y2 VADDPS Y2, Y1, Y1 VEXTRACTF128 $1, Y1, X2 VADDPS X2, X1, X1 VPSHUFD $0x4E, X1, X2 VADDPS X2, X1, X1 VPSHUFD $0xB1, X1, X2 VADDPS X2, X1, X1 // multiply by d MULSS X0, X1 MOVSS X1, ret+16(FP) VZEROUPPER RET TEXT ·dotQ4KInnerAVX512(SB), NOSPLIT, $0-40 MOVQ qs+0(FP), DI MOVQ x+8(FP), SI LEAQ 128(SI), R10 VBROADCASTSS d1+16(FP), Z0 VBROADCASTSS m1+20(FP), Z1 VBROADCASTSS d2+24(FP), Z2 VBROADCASTSS m2+28(FP), Z3 MOVL $0x0F0F0F0F, AX MOVD AX, X7 VPBROADCASTD X7, Z7 VXORPS Z12, Z12, Z12 VXORPS Z13, Z13, Z13 VXORPS Z14, Z14, Z14 VXORPS Z15, Z15, Z15 MOVQ $0, CX dot_q4k_512_loop: CMPQ CX, $32 JGE dot_q4k_512_reduce VPMOVZXBD (DI), Z8 VPANDD Z7, Z8, Z9 VCVTDQ2PS Z9, Z9 VMOVUPS (SI), Z10 VADDPS Z10, Z13, Z13 VFMADD231PS Z10, Z9, Z12 VPSRLD $4, Z8, Z8 VCVTDQ2PS Z8, Z8 VMOVUPS (R10), Z11 VADDPS Z11, Z15, Z15 VFMADD231PS Z11, Z8, Z14 ADDQ $16, DI ADDQ $64, SI ADDQ $64, R10 ADDQ $16, CX JMP dot_q4k_512_loop dot_q4k_512_reduce: VMULPS Z0, Z12, Z12 VMULPS Z1, Z13, Z13 VSUBPS Z13, Z12, Z12 VMULPS Z2, Z14, Z14 VMULPS Z3, Z15, Z15 VSUBPS Z15, Z14, Z14 VADDPS Z14, Z12, Z12 VEXTRACTF32X8 $1, Z12, Y1 VADDPS Y1, Y12, Y12 VEXTRACTF128 $1, Y12, X1 VADDPS X1, X12, X12 VPSHUFD $0x4E, X12, X1 VADDPS X1, X12, X12 VPSHUFD $0xB1, X12, X1 VADDPS X1, X12, X12 MOVSS X12, ret+32(FP) VZEROUPPER RET // func dotQ5KInnerAVX512(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32 TEXT ·dotQ5KInnerAVX512(SB), NOSPLIT, $0-64 MOVQ qs+0(FP), DI MOVQ qh+8(FP), SI MOVQ x+16(FP), DX LEAQ 128(DX), R10 VBROADCASTSS d1+24(FP), Z0 VBROADCASTSS m1+28(FP), Z1 VBROADCASTSS d2+32(FP), Z2 VBROADCASTSS m2+36(FP), Z3 MOVL $0x0F0F0F0F, AX MOVD AX, X7 VPBROADCASTD X7, Z7 MOVL $1, AX MOVD AX, X4 VPBROADCASTD X4, Z4 MOVQ u1+40(FP), CX MOVQ $1, AX SHLQ CL, AX MOVL AX, BX MOVD BX, X6 VPBROADCASTD X6, Z6 MOVQ u2+48(FP), CX MOVQ $1, AX SHLQ CL, AX MOVL AX, BX MOVD BX, X5 VPBROADCASTD X5, Z5 VXORPS Z15, Z15, Z15 MOVQ u1+40(FP), AX CMPQ AX, $0 JE dot_q5k_loop_s0 CMPQ AX, $2 JE dot_q5k_loop_s1 CMPQ AX, $4 JE dot_q5k_loop_s2 JMP dot_q5k_loop_s3 dot_q5k_loop_s0: MOVQ $0, CX dot_q5k_loop0: CMPQ CX, $32 JGE dot_q5k_reduce VPMOVZXBD (DI), Z11 VPANDD Z7, Z11, Z9 VPSRLD $4, Z11, Z10 VPMOVZXBD (SI), Z12 VPANDD Z6, Z12, Z13 VPSRLD $0, Z13, Z13 VPANDD Z4, Z13, Z13 VPSLLD $4, Z13, Z13 VPANDD Z5, Z12, Z8 VPSRLD $1, Z8, Z8 VPANDD Z4, Z8, Z8 VPSLLD $4, Z8, Z8 VPADDD Z13, Z9, Z9 VPADDD Z8, Z10, Z10 VCVTDQ2PS Z9, Z9 VMULPS Z0, Z9, Z9 VSUBPS Z1, Z9, Z9 VMOVUPS (DX), Z14 VFMADD231PS Z14, Z9, Z15 VCVTDQ2PS Z10, Z10 VMULPS Z2, Z10, Z10 VSUBPS Z3, Z10, Z10 VMOVUPS (R10), Z14 VFMADD231PS Z14, Z10, Z15 ADDQ $16, DI ADDQ $16, SI ADDQ $64, DX ADDQ $64, R10 ADDQ $16, CX JMP dot_q5k_loop0 dot_q5k_loop_s1: MOVQ $0, CX dot_q5k_loop1: CMPQ CX, $32 JGE dot_q5k_reduce VPMOVZXBD (DI), Z11 VPANDD Z7, Z11, Z9 VPSRLD $4, Z11, Z10 VPMOVZXBD (SI), Z12 VPANDD Z6, Z12, Z13 VPSRLD $2, Z13, Z13 VPANDD Z4, Z13, Z13 VPSLLD $4, Z13, Z13 VPANDD Z5, Z12, Z8 VPSRLD $3, Z8, Z8 VPANDD Z4, Z8, Z8 VPSLLD $4, Z8, Z8 VPADDD Z13, Z9, Z9 VPADDD Z8, Z10, Z10 VCVTDQ2PS Z9, Z9 VMULPS Z0, Z9, Z9 VSUBPS Z1, Z9, Z9 VMOVUPS (DX), Z14 VFMADD231PS Z14, Z9, Z15 VCVTDQ2PS Z10, Z10 VMULPS Z2, Z10, Z10 VSUBPS Z3, Z10, Z10 VMOVUPS (R10), Z14 VFMADD231PS Z14, Z10, Z15 ADDQ $16, DI ADDQ $16, SI ADDQ $64, DX ADDQ $64, R10 ADDQ $16, CX JMP dot_q5k_loop1 dot_q5k_loop_s2: MOVQ $0, CX dot_q5k_loop2: CMPQ CX, $32 JGE dot_q5k_reduce VPMOVZXBD (DI), Z11 VPANDD Z7, Z11, Z9 VPSRLD $4, Z11, Z10 VPMOVZXBD (SI), Z12 VPANDD Z6, Z12, Z13 VPSRLD $4, Z13, Z13 VPANDD Z4, Z13, Z13 VPSLLD $4, Z13, Z13 VPANDD Z5, Z12, Z8 VPSRLD $5, Z8, Z8 VPANDD Z4, Z8, Z8 VPSLLD $4, Z8, Z8 VPADDD Z13, Z9, Z9 VPADDD Z8, Z10, Z10 VCVTDQ2PS Z9, Z9 VMULPS Z0, Z9, Z9 VSUBPS Z1, Z9, Z9 VMOVUPS (DX), Z14 VFMADD231PS Z14, Z9, Z15 VCVTDQ2PS Z10, Z10 VMULPS Z2, Z10, Z10 VSUBPS Z3, Z10, Z10 VMOVUPS (R10), Z14 VFMADD231PS Z14, Z10, Z15 ADDQ $16, DI ADDQ $16, SI ADDQ $64, DX ADDQ $64, R10 ADDQ $16, CX JMP dot_q5k_loop2 dot_q5k_loop_s3: MOVQ $0, CX dot_q5k_loop3: CMPQ CX, $32 JGE dot_q5k_reduce VPMOVZXBD (DI), Z11 VPANDD Z7, Z11, Z9 VPSRLD $4, Z11, Z10 VPMOVZXBD (SI), Z12 VPANDD Z6, Z12, Z13 VPSRLD $6, Z13, Z13 VPANDD Z4, Z13, Z13 VPSLLD $4, Z13, Z13 VPANDD Z5, Z12, Z8 VPSRLD $7, Z8, Z8 VPANDD Z4, Z8, Z8 VPSLLD $4, Z8, Z8 VPADDD Z13, Z9, Z9 VPADDD Z8, Z10, Z10 VCVTDQ2PS Z9, Z9 VMULPS Z0, Z9, Z9 VSUBPS Z1, Z9, Z9 VMOVUPS (DX), Z14 VFMADD231PS Z14, Z9, Z15 VCVTDQ2PS Z10, Z10 VMULPS Z2, Z10, Z10 VSUBPS Z3, Z10, Z10 VMOVUPS (R10), Z14 VFMADD231PS Z14, Z10, Z15 ADDQ $16, DI ADDQ $16, SI ADDQ $64, DX ADDQ $64, R10 ADDQ $16, CX JMP dot_q5k_loop3 dot_q5k_reduce: VEXTRACTF32X8 $1, Z15, Y1 VADDPS Y1, Y15, Y15 VEXTRACTF128 $1, Y15, X1 VADDPS X1, X15, X15 VPSHUFD $0x4E, X15, X1 VADDPS X1, X15, X15 VPSHUFD $0xB1, X15, X1 VADDPS X1, X15, X15 MOVSS X15, ret+56(FP) VZEROUPPER RET TEXT ·dotQ6KInnerAVX512(SB), NOSPLIT, $0-40 MOVQ ql+0(FP), DI MOVQ qh+8(FP), SI MOVQ scales+16(FP), DX MOVQ x+24(FP), R8 MOVL $0x0F, AX MOVD AX, X11 VPBROADCASTD X11, Z11 MOVL $0x03, AX MOVD AX, X10 VPBROADCASTD X10, Z10 MOVL $0x42000000, AX MOVD AX, X9 VBROADCASTSS X9, Z9 VXORPS Z15, Z15, Z15 MOVQ $0, CX dot_q6k_512_loop: CMPQ CX, $32 JGE dot_q6k_512_reduce MOVQ CX, R11 SHRQ $4, R11 VPMOVZXBD (DI)(CX*1), Z0 VPMOVZXBD 32(DI)(CX*1), Z1 VPMOVZXBD (SI)(CX*1), Z2 VPANDD Z11, Z0, Z3 VPANDD Z10, Z2, Z4 VPSLLD $4, Z4, Z4 VPADDD Z4, Z3, Z3 VCVTDQ2PS Z3, Z3 VSUBPS Z9, Z3, Z3 VBROADCASTSS (DX)(R11*4), Z4 VMULPS Z4, Z3, Z3 LEAQ (R8)(CX*4), R12 VMOVUPS (R12), Z5 VFMADD231PS Z5, Z3, Z15 VPANDD Z11, Z1, Z3 VPSRLD $2, Z2, Z4 VPANDD Z10, Z4, Z4 VPSLLD $4, Z4, Z4 VPADDD Z4, Z3, Z3 VCVTDQ2PS Z3, Z3 VSUBPS Z9, Z3, Z3 VBROADCASTSS 8(DX)(R11*4), Z4 VMULPS Z4, Z3, Z3 VMOVUPS 128(R12), Z5 VFMADD231PS Z5, Z3, Z15 VPSRLD $4, Z0, Z3 VPANDD Z11, Z3, Z3 VPSRLD $4, Z2, Z4 VPANDD Z10, Z4, Z4 VPSLLD $4, Z4, Z4 VPADDD Z4, Z3, Z3 VCVTDQ2PS Z3, Z3 VSUBPS Z9, Z3, Z3 VBROADCASTSS 16(DX)(R11*4), Z4 VMULPS Z4, Z3, Z3 VMOVUPS 256(R12), Z5 VFMADD231PS Z5, Z3, Z15 VPSRLD $4, Z1, Z3 VPANDD Z11, Z3, Z3 VPSRLD $6, Z2, Z4 VPANDD Z10, Z4, Z4 VPSLLD $4, Z4, Z4 VPADDD Z4, Z3, Z3 VCVTDQ2PS Z3, Z3 VSUBPS Z9, Z3, Z3 VBROADCASTSS 24(DX)(R11*4), Z4 VMULPS Z4, Z3, Z3 VMOVUPS 384(R12), Z5 VFMADD231PS Z5, Z3, Z15 ADDQ $16, CX JMP dot_q6k_512_loop dot_q6k_512_reduce: VEXTRACTF32X8 $1, Z15, Y1 VADDPS Y1, Y15, Y15 VEXTRACTF128 $1, Y15, X1 VADDPS X1, X15, X15 VPSHUFD $0x4E, X15, X1 VADDPS X1, X15, X15 VPSHUFD $0xB1, X15, X1 VADDPS X1, X15, X15 MOVSS X15, ret+32(FP) VZEROUPPER RET