| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- //go:build amd64
- // +build amd64
- #include "textflag.h"
// func dotAVX512(a *float32, b *float32, n int) float32
//
// dotAVX512 returns the sum of a[i]*b[i] for i in [0, n), or 0 when n <= 0.
// Elements are consumed 16 at a time, then 4 at a time, then one at a time;
// every stage accumulates with a fused multiply-add.
//
// The caller must make sure the CPU supports AVX-512F (and FMA) before
// calling; nothing in this routine checks for it.
//
// Registers:
//   DI     = read cursor into a
//   SI     = read cursor into b
//   CX     = elements still to process
//   Z0     = running partial sums (narrowed to X0 after the 16-wide loop)
//   Z1, Z2 = scratch
//
// Frame: no locals; 24 bytes of arguments plus a 4-byte result at ret+24(FP),
// hence $0-28.
TEXT ·dotAVX512(SB), NOSPLIT, $0-28
	MOVQ	a+0(FP), DI
	MOVQ	b+8(FP), SI
	MOVQ	n+16(FP), CX

	// Clear the accumulator. A VEX-encoded 128-bit write also clears every
	// bit above 127, so this zeroes all of Z0.
	VXORPS	X0, X0, X0

	// n <= 0: store the zero accumulator and return.
	TESTQ	CX, CX
	JLE	dot512_store

	// 16 floats (one ZMM register) per iteration.
dot512_loop16:
	CMPQ	CX, $16
	JL	dot512_fold
	VMOVUPS	(DI), Z1
	VMOVUPS	(SI), Z2
	VFMADD231PS	Z1, Z2, Z0	// Z0 += Z1 * Z2
	ADDQ	$64, DI
	ADDQ	$64, SI
	SUBQ	$16, CX
	JMP	dot512_loop16

	// Fold the 16 partial sums in Z0 down to 4 in X0 before the tails.
	// VEXTRACTF64X4 is used instead of VEXTRACTF32X8 because the latter
	// needs AVX512DQ; both move the same upper 256 bits.
dot512_fold:
	VEXTRACTF64X4	$1, Z0, Y1
	VADDPS	Y1, Y0, Y0
	VEXTRACTF128	$1, Y0, X1
	VADDPS	X1, X0, X0

	// 4 floats (one XMM register) per iteration.
dot512_loop4:
	CMPQ	CX, $4
	JL	dot512_loop1
	VMOVUPS	(DI), X1
	VMOVUPS	(SI), X2
	VFMADD231PS	X1, X2, X0
	ADDQ	$16, DI
	ADDQ	$16, SI
	SUBQ	$4, CX
	JMP	dot512_loop4

	// Scalar tail: at most 3 elements remain. The scalar FMA only updates
	// lane 0 of X0, so the partial sums in lanes 1-3 are preserved.
	// VMOVSS (not legacy MOVSS) avoids mixing SSE encodings with live
	// upper vector state.
dot512_loop1:
	TESTQ	CX, CX
	JE	dot512_reduce
	VMOVSS	(DI), X1
	VMOVSS	(SI), X2
	VFMADD231SS	X1, X2, X0
	ADDQ	$4, DI
	ADDQ	$4, SI
	DECQ	CX
	JMP	dot512_loop1

	// Horizontal sum of the 4 lanes of X0 into lane 0.
dot512_reduce:
	VMOVHLPS	X0, X0, X1	// X1[0..1] = X0[2..3]
	VADDPS	X1, X0, X0	// lane0 = x0+x2, lane1 = x1+x3
	VPSHUFD	$0xB1, X0, X1	// swap adjacent lanes: X1[0] = X0[1]
	VADDPS	X1, X0, X0	// lane0 = x0+x1+x2+x3

	// Common exit. VZEROUPPER leaves bits 0-127 intact, so the result in
	// lane 0 of X0 survives it.
dot512_store:
	VZEROUPPER
	VMOVSS	X0, ret+24(FP)
	RET
// func axpyAVX512(alpha float32, x *float32, y *float32, n int)
//
// axpyAVX512 updates y in place: y[i] = alpha*x[i] + y[i] for i in [0, n).
// It does nothing when n <= 0. Elements are processed 16 at a time, then
// one at a time, each with a fused multiply-add.
//
// The caller must make sure the CPU supports AVX-512F (and FMA) before
// calling; nothing in this routine checks for it.
//
// Registers:
//   DI     = read cursor into x
//   SI     = read/write cursor into y
//   CX     = elements still to process
//   Z0     = alpha broadcast to all 16 lanes (lane 0 is reused by the tail)
//   Z1, Z2 = scratch
//
// Frame: no locals; alpha occupies bytes 0-3, x, y and n are 8-byte words
// at offsets 8, 16 and 24, so the argument area is 32 bytes ($0-32).
TEXT ·axpyAVX512(SB), NOSPLIT, $0-32
	// Check the count first so the empty case touches no vector state.
	MOVQ	n+24(FP), CX
	TESTQ	CX, CX
	JLE	axpy512_ret

	MOVQ	x+8(FP), DI
	MOVQ	y+16(FP), SI
	VMOVSS	alpha+0(FP), X0
	VBROADCASTSS	X0, Z0

	// 16 floats (one ZMM register) per iteration.
axpy512_loop16:
	CMPQ	CX, $16
	JL	axpy512_loop1
	VMOVUPS	(DI), Z1
	VMOVUPS	(SI), Z2
	VFMADD231PS	Z0, Z1, Z2	// Z2 += Z1 * Z0  (y += x * alpha)
	VMOVUPS	Z2, (SI)
	ADDQ	$64, DI
	ADDQ	$64, SI
	SUBQ	$16, CX
	JMP	axpy512_loop16

	// Scalar tail: fewer than 16 elements remain. VMOVSS (not legacy MOVSS)
	// avoids mixing SSE encodings with live upper vector state.
axpy512_loop1:
	TESTQ	CX, CX
	JE	axpy512_done
	VMOVSS	(DI), X1
	VMOVSS	(SI), X2
	VFMADD231SS	X0, X1, X2	// X2[0] += X1[0] * X0[0]
	VMOVSS	X2, (SI)
	ADDQ	$4, DI
	ADDQ	$4, SI
	DECQ	CX
	JMP	axpy512_loop1

	// ZMM registers were written; clear the upper state before returning.
axpy512_done:
	VZEROUPPER

axpy512_ret:
	RET
|