//go:build amd64
// +build amd64

#include "textflag.h"

// func dotAVX512(a *float32, b *float32, n int) float32
//
// Returns the sum of a[i]*b[i] for i in [0, n). Returns 0 when n <= 0.
//
// ABI:    Go ABI0 (stack arguments); leaf function, no stack frame.
// Frame:  24 bytes of arguments + 4-byte float32 result at offset 24 = 28.
// ISA:    AVX-512F and FMA only (no AVX512DQ/VL instructions are used).
//         NOTE(review): feature detection is presumably done by the Go caller.
// In:     a+0(FP), b+8(FP) = pointers to the inputs; n+16(FP) = element count.
//         Assumes both pointers address at least n readable float32s.
// Regs:   DI = a cursor, SI = b cursor, CX = elements remaining,
//         Z0/Y0/X0 = accumulator, Z1/Z2 (and X1/X2) = scratch.
TEXT ·dotAVX512(SB), NOSPLIT, $0-28
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), CX

	// n <= 0: return +0.0 before any vector register is touched, so this
	// path needs no VZEROUPPER.
	TESTQ CX, CX
	JLE   dot512_zero

	VXORPS Z0, Z0, Z0            // 16-lane accumulator = 0

	// Main loop: 16 floats per iteration.
dot512_loop16:
	CMPQ CX, $16
	JL   dot512_fold
	VMOVUPS     (DI), Z1
	VMOVUPS     (SI), Z2
	VFMADD231PS Z1, Z2, Z0       // Z0 += Z1 * Z2
	ADDQ $64, DI
	ADDQ $64, SI
	SUBQ $16, CX
	JMP  dot512_loop16

	// Fold the 16 lanes down to 4 (zmm0 -> ymm0 -> xmm0) so the tails can
	// keep accumulating into X0. VEXTRACTF64X4 is the AVX512F-only way to
	// take the upper 256 bits; without a mask it moves the same bits as
	// VEXTRACTF32X8, which would additionally require AVX512DQ.
dot512_fold:
	VEXTRACTF64X4 $1, Z0, Y1
	VADDPS        Y1, Y0, Y0
	VEXTRACTF128  $1, Y0, X1
	VADDPS        X1, X0, X0

	// Vector tail: 4 floats per iteration.
dot512_loop4:
	CMPQ CX, $4
	JL   dot512_loop1
	VMOVUPS     (DI), X1
	VMOVUPS     (SI), X2
	VFMADD231PS X1, X2, X0
	ADDQ $16, DI
	ADDQ $16, SI
	SUBQ $4, CX
	JMP  dot512_loop4

	// Scalar tail: 0..3 remaining floats, accumulated into lane 0 of X0.
	// VEX-encoded VMOVSS is used instead of legacy MOVSS so no SSE/AVX
	// transition occurs while the upper vector state is still dirty.
dot512_loop1:
	TESTQ CX, CX
	JE    dot512_reduce
	VMOVSS      (DI), X1
	VMOVSS      (SI), X2
	VFMADD231SS X1, X2, X0       // lane 0 += X1 * X2; lanes 1..3 unchanged
	ADDQ $4, DI
	ADDQ $4, SI
	DECQ CX
	JMP  dot512_loop1

	// Horizontal sum of the 4 lanes of X0 into lane 0.
dot512_reduce:
	VMOVHLPS X0, X0, X1          // X1 = {x2, x3, x2, x3}
	VADDPS   X1, X0, X0          // lanes 0,1 = x0+x2, x1+x3
	VPSHUFD  $0xB1, X0, X1       // swap adjacent lanes: X1[0] = X0[1]
	VADDPS   X1, X0, X0          // lane 0 = x0+x1+x2+x3
	VMOVSS   X0, ret+24(FP)
	VZEROUPPER                   // leave clean upper state for SSE-based Go code
	RET

dot512_zero:
	MOVL $0, ret+24(FP)          // all-zero bits == +0.0 as float32
	RET

// func axpyAVX512(alpha float32, x *float32, y *float32, n int)
//
// Computes y[i] += alpha * x[i] for i in [0, n), in place. No-op when n <= 0.
//
// ABI:    Go ABI0 (stack arguments); leaf function, no stack frame.
// Frame:  alpha at 0 (4 bytes + 4 bytes padding), x at 8, y at 16, n at 24,
//         for 32 bytes of arguments and no result.
// ISA:    AVX-512F and FMA.
//         NOTE(review): feature detection is presumably done by the Go caller.
// In:     Assumes x addresses n readable and y addresses n writable float32s.
// Regs:   DI = x cursor, SI = y cursor, CX = elements remaining,
//         Z0 = alpha broadcast to all lanes (X0 lane 0 = alpha),
//         Z1/Z2 (and X1/X2) = scratch.
TEXT ·axpyAVX512(SB), NOSPLIT, $0-32
	MOVQ x+8(FP), DI
	MOVQ y+16(FP), SI
	MOVQ n+24(FP), CX

	TESTQ CX, CX
	JLE   axpy512_done

	VMOVSS       alpha+0(FP), X0
	VBROADCASTSS X0, Z0          // every lane of Z0 = alpha

	// Main loop: 16 floats per iteration.
axpy512_loop16:
	CMPQ CX, $16
	JL   axpy512_loop1
	VMOVUPS     (DI), Z1         // Z1 = x[i .. i+15]
	VMOVUPS     (SI), Z2         // Z2 = y[i .. i+15]
	VFMADD231PS Z0, Z1, Z2       // Z2 += Z1 * Z0
	VMOVUPS     Z2, (SI)
	ADDQ $64, DI
	ADDQ $64, SI
	SUBQ $16, CX
	JMP  axpy512_loop16

	// Scalar tail: 0..15 remaining floats. VEX-encoded moves avoid SSE/AVX
	// transitions while the ZMM upper state is dirty.
axpy512_loop1:
	TESTQ CX, CX
	JE    axpy512_done
	VMOVSS      (DI), X1
	VMOVSS      (SI), X2
	VFMADD231SS X0, X1, X2       // X2 lane 0 += x * alpha
	VMOVSS      X2, (SI)
	ADDQ $4, DI
	ADDQ $4, SI
	DECQ CX
	JMP  axpy512_loop1

axpy512_done:
	VZEROUPPER                   // harmless on the n <= 0 path; required after ZMM use
	RET