// AVX2/FMA float32 kernels (Go Plan 9 assembly, amd64).
//go:build amd64
// +build amd64

#include "textflag.h"
// func dotAVX2(a *float32, b *float32, n int) float32
//
// Dot product of two float32 vectors of length n using AVX2 + FMA.
// In:   a+0(FP), b+8(FP) = base pointers; n+16(FP) = element count.
// Out:  ret+24(FP) = sum(a[i]*b[i]); 0 when n <= 0.
// Arg frame: 8+8+8 args + 4-byte result at offset 24 -> 28 bytes total,
// so the TEXT size must be $0-28 (the original $0-24 put the return
// slot outside the declared frame and fails `go vet` asmdecl checks).
TEXT ·dotAVX2(SB), NOSPLIT, $0-28
	MOVQ a+0(FP), DI            // DI = &a[0]
	MOVQ b+8(FP), SI            // SI = &b[0]
	MOVQ n+16(FP), CX           // CX = elements remaining

	// Accumulator ymm0 = 0 (8 partial sums).
	VXORPS Y0, Y0, Y0
	TESTQ CX, CX
	JLE dot_zero                // n <= 0: result is 0

	// Main loop: 8 floats per iteration.
loop8:
	CMPQ CX, $8
	JL fold
	VMOVUPS (DI), Y1
	VMOVUPS (SI), Y2
	VFMADD231PS Y1, Y2, Y0      // Y0 += Y1 * Y2
	ADDQ $32, DI
	ADDQ $32, SI
	SUBQ $8, CX
	JMP loop8

	// Fold ymm0 upper half into xmm0 before handling tails.
fold:
	VEXTRACTF128 $1, Y0, X1
	VADDPS X1, X0, X0           // X0 = low 4 partial sums + high 4

	// 4-wide (SSE-width) tail.
loop1:
	CMPQ CX, $4
	JL loop_scalar
	VMOVUPS (DI), X1
	VMOVUPS (SI), X2
	VFMADD231PS X1, X2, X0      // X0 += X1 * X2
	ADDQ $16, DI
	ADDQ $16, SI
	SUBQ $4, CX
	JMP loop1

	// Scalar tail: one float per iteration.
loop_scalar:
	TESTQ CX, CX
	JE reduce4
	MOVSS (DI), X1
	MOVSS (SI), X2
	VFMADD231SS X1, X2, X0      // X0[0] += a[i] * b[i]
	ADDQ $4, DI
	ADDQ $4, SI
	DECQ CX
	JMP loop_scalar

	// Horizontal sum of xmm0 (4 lanes) down to lane 0.
reduce4:
	VMOVHLPS X0, X0, X1         // X1 lanes 0,1 = X0 lanes 2,3
	VADDPS X1, X0, X0           // lanes 0,1 = (0+2),(1+3)
	VPSHUFD $0xB1, X0, X1       // swap adjacent lanes
	VADDPS X1, X0, X0           // lane 0 = total
	MOVSS X0, ret+24(FP)
	VZEROUPPER                  // clear dirty YMM upper state before SSE caller
	RET

dot_zero:
	VXORPS X0, X0, X0
	MOVSS X0, ret+24(FP)
	VZEROUPPER                  // VXORPS Y0 above dirtied upper state; clear it here too
	RET
// func axpyAVX2(alpha float32, x *float32, y *float32, n int)
//
// y[i] += alpha * x[i] for i in [0, n) using AVX2 + FMA; no-op when n <= 0.
// Arg frame: alpha+0 (4 bytes + 4 padding), x+8, y+16, n+24 -> 32 bytes,
// so the TEXT size must be $0-32 (the original $0-28 understates the frame:
// n occupies bytes 24..31 and trips `go vet` asmdecl checks).
TEXT ·axpyAVX2(SB), NOSPLIT, $0-32
	MOVSS alpha+0(FP), X0
	VBROADCASTSS X0, Y0         // Y0 = {alpha} x 8
	MOVQ x+8(FP), DI            // DI = &x[0]
	MOVQ y+16(FP), SI           // SI = &y[0]
	MOVQ n+24(FP), CX           // CX = elements remaining
	TESTQ CX, CX
	JLE axpy_done               // n <= 0: nothing to do

	// Main loop: 8 floats per iteration.
axpy_loop8:
	CMPQ CX, $8
	JL axpy_loop1
	VMOVUPS (DI), Y1            // Y1 = x[i:i+8]
	VMOVUPS (SI), Y2            // Y2 = y[i:i+8]
	VFMADD231PS Y0, Y1, Y2      // Y2 += alpha * x
	VMOVUPS Y2, (SI)
	ADDQ $32, DI
	ADDQ $32, SI
	SUBQ $8, CX
	JMP axpy_loop8

	// Scalar tail: one float per iteration.
axpy_loop1:
	TESTQ CX, CX
	JE axpy_done
	MOVSS (DI), X1
	MOVSS (SI), X2
	VFMADD231SS X0, X1, X2      // X2[0] += alpha * x[i]
	MOVSS X2, (SI)
	ADDQ $4, DI
	ADDQ $4, SI
	DECQ CX
	JMP axpy_loop1

axpy_done:
	VZEROUPPER                  // VBROADCASTSS/256-bit loop dirtied YMM upper state
	RET