//go:build amd64
// +build amd64

#include "textflag.h"

// func dotAVX2(a *float32, b *float32, n int) float32
//
// Returns the sum of a[i]*b[i] for i in [0, n); returns 0 when n <= 0.
// a and b must each point to at least n readable float32 values; no
// alignment is required (all vector loads are unaligned loads).
// Uses AVX and FMA instructions; the caller is responsible for checking
// CPU support before calling.
//
// Frame: 3 x 8-byte arguments + 4-byte float32 result at offset 24 = 28.
//
// Registers:
//   DI = cursor into a
//   SI = cursor into b
//   CX = elements remaining
//   Y0/X0 = running partial sums
//   X1, X2 / Y1, Y2 = scratch
TEXT ·dotAVX2(SB), NOSPLIT, $0-28
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), CX

	// Bail out before touching any YMM register so the early exit
	// leaves the upper vector state untouched.
	TESTQ CX, CX
	JLE   dot_zero

	// Accumulator Y0 = 0.
	VXORPS Y0, Y0, Y0

	// 8 floats per iteration.
loop8:
	CMPQ CX, $8
	JL   fold
	VMOVUPS     (DI), Y1
	VMOVUPS     (SI), Y2
	VFMADD231PS Y1, Y2, Y0 // Y0 += Y1 * Y2
	ADDQ $32, DI
	ADDQ $32, SI
	SUBQ $8, CX
	JMP  loop8

	// Fold the upper 128 bits of Y0 into X0 before handling the tails.
fold:
	VEXTRACTF128 $1, Y0, X1
	VADDPS       X1, X0, X0

	// 4 floats per iteration (at most one pass, since CX < 8 here).
loop4:
	CMPQ CX, $4
	JL   loop_scalar
	VMOVUPS     (DI), X1
	VMOVUPS     (SI), X2
	VFMADD231PS X1, X2, X0
	ADDQ $16, DI
	ADDQ $16, SI
	SUBQ $4, CX
	JMP  loop4

	// Scalar tail: 0..3 remaining elements accumulate into lane 0 of X0.
	// VEX-encoded moves are used so no legacy-SSE instruction executes
	// while the upper YMM halves are dirty.
loop_scalar:
	TESTQ CX, CX
	JE    reduce4
	VMOVSS      (DI), X1
	VMOVSS      (SI), X2
	VFMADD231SS X1, X2, X0 // lane 0 += X1 * X2; lanes 1..3 unchanged
	ADDQ $4, DI
	ADDQ $4, SI
	DECQ CX
	JMP  loop_scalar

	// Horizontal sum of the 4 lanes of X0 into lane 0.
reduce4:
	VMOVHLPS X0, X0, X1       // X1[0..1] = X0[2..3]
	VADDPS   X1, X0, X0       // X0[0] = x0+x2, X0[1] = x1+x3
	VPSHUFD  $0xB1, X0, X1    // swap adjacent lanes: X1[0] = X0[1]
	VADDPS   X1, X0, X0       // X0[0] = x0+x1+x2+x3

	// Clear the upper YMM halves first (low 128 bits are preserved),
	// then store the result with a plain SSE move.
	VZEROUPPER
	MOVSS X0, ret+24(FP)
	RET

dot_zero:
	// float32(0) has an all-zero bit pattern.
	MOVL $0, ret+24(FP)
	RET

// func axpyAVX2(alpha float32, x *float32, y *float32, n int)
//
// Computes y[i] += alpha * x[i] for i in [0, n), in place. Does nothing
// when n <= 0. x must point to at least n readable float32 values and y to
// at least n writable ones; no alignment is required.
// Uses AVX and FMA instructions; the caller is responsible for checking
// CPU support before calling.
//
// Frame: alpha occupies bytes 0..3 (padded to 8), x at 8, y at 16,
// n at 24..31, so the argument area is 32 bytes.
//
// Registers:
//   DI = cursor into x
//   SI = cursor into y
//   CX = elements remaining
//   Y0 = alpha broadcast to all 8 lanes (X0 lane 0 = alpha)
//   X1, X2 / Y1, Y2 = scratch
TEXT ·axpyAVX2(SB), NOSPLIT, $0-32
	MOVQ x+8(FP), DI
	MOVQ y+16(FP), SI
	MOVQ n+24(FP), CX

	// Bail out before touching any YMM register.
	TESTQ CX, CX
	JLE   axpy_ret

	// Broadcast alpha straight from the argument slot.
	VBROADCASTSS alpha+0(FP), Y0

	// 8 floats per iteration.
axpy_loop8:
	CMPQ CX, $8
	JL   axpy_loop1
	VMOVUPS     (DI), Y1
	VMOVUPS     (SI), Y2
	VFMADD231PS Y0, Y1, Y2 // Y2 += Y1 * Y0  (y += x * alpha)
	VMOVUPS     Y2, (SI)
	ADDQ $32, DI
	ADDQ $32, SI
	SUBQ $8, CX
	JMP  axpy_loop8

	// Scalar tail: 0..7 remaining elements, VEX-encoded throughout.
axpy_loop1:
	TESTQ CX, CX
	JE    axpy_done
	VMOVSS      (DI), X1
	VMOVSS      (SI), X2
	VFMADD231SS X0, X1, X2 // X2 += X1 * alpha
	VMOVSS      X2, (SI)
	ADDQ $4, DI
	ADDQ $4, SI
	DECQ CX
	JMP  axpy_loop1

axpy_done:
	// Leave the upper YMM halves clean for the SSE code that follows.
	VZEROUPPER

axpy_ret:
	RET