| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- //go:build amd64
- // +build amd64
- #include "textflag.h"
- // func softmaxMaxAVX512Asm(x *float32, n int) float32
- //
- // Returns the maximum of x[0..n-1]. When n <= 0 no memory is read and the
- // seed value ·negInfConst is returned (presumably -Inf; defined elsewhere).
- // Requires AVX-512F only: the 512->256 fold uses VEXTRACTF64X4 (AVX-512F)
- // rather than VEXTRACTF32X8 (AVX512DQ); both move the same upper 256 bits.
- // Registers: DI = cursor, CX = elements left, Z0 = running lane-wise max.
- // Frame: 8+8 bytes of args, float32 result at +16 => $0-20 (go vet asmdecl).
- TEXT ·softmaxMaxAVX512Asm(SB), NOSPLIT, $0-20
- MOVQ x+0(FP), DI
- MOVQ n+8(FP), CX
- VBROADCASTSS ·negInfConst(SB), Z0
- loop_max:
- CMPQ CX, $16
- JL max_reduce
- VMOVUPS (DI), Z1
- VMAXPS Z1, Z0, Z0
- ADDQ $64, DI
- SUBQ $16, CX
- JMP loop_max
- max_reduce:
- // Fold 16 lanes into lane 0: 512->256->128 bits, then swap 64-bit halves,
- // then swap adjacent 32-bit lanes.
- VEXTRACTF64X4 $1, Z0, Y1
- VMAXPS Y1, Y0, Y0
- VEXTRACTF128 $1, Y0, X1
- VMAXPS X1, X0, X0
- VPERMILPS $0x4E, X0, X1
- VMAXPS X1, X0, X0
- VPERMILPS $0xB1, X0, X1
- VMAXPS X1, X0, X0
- // Scalar tail for the last 0..15 elements. JLE (not JE) keeps a negative
- // count out of the DECQ/JNZ loop, which would otherwise run off the end of x.
- TESTQ CX, CX
- JLE max_done
- max_tail:
- VMOVSS (DI), X1
- VMAXSS X1, X0, X0
- ADDQ $4, DI
- DECQ CX
- JNZ max_tail
- max_done:
- VMOVSS X0, ret+16(FP)
- // Clear dirty upper YMM/ZMM state before returning to SSE-encoded Go code.
- VZEROUPPER
- RET
- // func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
- //
- // Overwrites x[i] with exp(x[i]-max) for 0 <= i < n and returns the sum of
- // the new values. n <= 0 returns 0 and touches no memory. Requires AVX-512F.
- //
- // exp is computed per lane as follows:
- //   clamp x-max to [expLo, expHi]
- //   fx = floor(x*log2EF + 0.5)
- //   r  = x - fx*C1 - fx*C2
- //   y  = poly(r) * r^2 + r + 1
- //   y  = y * 2^fx, where 2^fx is built directly in the float exponent field.
- // Constant values (expHi, expLo, C1, C2, P0..P5, bias) are defined elsewhere.
- //
- // The trailing n%16 elements go through the same block with only the low
- // n%16 lanes enabled in K1. Masked loads and stores do not fault on
- // disabled lanes, so nothing past x[n-1] is read or written.
- // NOTE(review): if the Go wrapper used to finish the tail itself, remove
- // that loop, or the tail elements will be exponentiated twice.
- //
- // Registers: DI = cursor, CX = elements left, K1 = active lanes,
- // Z5 = max, Z7 = running sum, Z8..Z15 = range-reduction constants,
- // Z16..Z21 = polynomial coefficients P0..P5 (hoisted out of the loop).
- TEXT ·softmaxExpSumAVX512Asm(SB), NOSPLIT, $0-28
- MOVQ x+0(FP), DI
- MOVQ n+8(FP), CX
- VBROADCASTSS max+16(FP), Z5
- VBROADCASTSS ·expHi(SB), Z14
- VBROADCASTSS ·expLo(SB), Z13
- VBROADCASTSS ·log2EF(SB), Z12
- VBROADCASTSS ·halfConst(SB), Z11
- VBROADCASTSS ·expC1(SB), Z10
- VBROADCASTSS ·expC2(SB), Z9
- VBROADCASTSS ·oneConst(SB), Z8
- VPBROADCASTD ·expBiasConst(SB), Z15
- VBROADCASTSS ·polyP0(SB), Z16
- VBROADCASTSS ·polyP1(SB), Z17
- VBROADCASTSS ·polyP2(SB), Z18
- VBROADCASTSS ·polyP3(SB), Z19
- VBROADCASTSS ·polyP4(SB), Z20
- VBROADCASTSS ·polyP5(SB), Z21
- VPXORD Z7, Z7, Z7 // sum = 0 (VPXORD is AVX-512F; VXORPS zmm needs DQ)
- // Full 16-lane blocks run with every lane enabled.
- MOVL $0xFFFF, AX
- KMOVW AX, K1
- loop_exp:
- CMPQ CX, $16
- JL exp_tail
- exp_block:
- VMOVUPS.Z (DI), K1, Z1 // disabled lanes load as 0
- VSUBPS Z5, Z1, Z1 // x - max
- VMINPS Z14, Z1, Z1 // clamp to [expLo, expHi]
- VMAXPS Z13, Z1, Z1
- VMULPS Z12, Z1, Z2 // fx = floor(x*log2EF + 0.5)
- VADDPS Z11, Z2, Z2
- VRNDSCALEPS $1, Z2, Z2 // imm 1 = round toward -Inf
- VCVTPS2DQ Z2, Z6 // Z6 = fx as int32 (exact: fx is already integral)
- VCVTDQ2PS Z6, Z3
- VMULPS Z10, Z3, Z4 // r = x - fx*C1 - fx*C2
- VSUBPS Z4, Z1, Z1
- VMULPS Z9, Z3, Z4
- VSUBPS Z4, Z1, Z1
- VMULPS Z1, Z1, Z4 // z = r*r
- VMULPS Z16, Z1, Z0 // y = ((((P0*r + P1)*r + P2)*r + P3)*r + P4)*r + P5
- VADDPS Z17, Z0, Z0
- VMULPS Z1, Z0, Z0
- VADDPS Z18, Z0, Z0
- VMULPS Z1, Z0, Z0
- VADDPS Z19, Z0, Z0
- VMULPS Z1, Z0, Z0
- VADDPS Z20, Z0, Z0
- VMULPS Z1, Z0, Z0
- VADDPS Z21, Z0, Z0
- VMULPS Z4, Z0, Z0 // y = y*z + r + 1
- VADDPS Z1, Z0, Z0
- VADDPS Z8, Z0, Z0
- VPADDD Z15, Z6, Z6 // 2^fx: (fx + bias) << 23 into the exponent field
- VPSLLD $23, Z6, Z6
- VMULPS Z6, Z0, Z0 // exp(x - max)
- VADDPS Z0, Z7, K1, Z7 // merge-masked: disabled lanes keep their sum
- VMOVUPS Z0, K1, (DI) // only active lanes are written back
- ADDQ $64, DI
- SUBQ $16, CX
- JG loop_exp
- JMP exp_reduce
- exp_tail:
- // CX <= 0: nothing left (also covers a negative n). Otherwise enable the
- // low CX lanes (CX in 1..15) and run the block once more; its SUBQ then
- // drives CX negative, so control falls through to the reduction.
- TESTQ CX, CX
- JLE exp_reduce
- MOVL $1, AX
- SHLL CX, AX
- DECL AX // AX = (1 << CX) - 1
- KMOVW AX, K1
- JMP exp_block
- exp_reduce:
- // Horizontal sum of Z7. The upper 256 bits must be added to Y7, the low
- // half of the accumulator (not to Y0, the last exp vector).
- VEXTRACTF64X4 $1, Z7, Y1
- VADDPS Y1, Y7, Y0
- VEXTRACTF128 $1, Y0, X1
- VADDPS X1, X0, X0
- VHADDPS X0, X0, X0
- VHADDPS X0, X0, X0
- VMOVSS X0, ret+24(FP)
- // Clear dirty upper YMM/ZMM state before returning to SSE-encoded Go code.
- VZEROUPPER
- RET
- // func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
- //
- // Multiplies x[0..n-1] in place by inv. n <= 0 is a no-op. Requires AVX-512F.
- // Frame: 8+8+4 bytes of args and no result => $0-20 (go vet asmdecl).
- // Registers: DI = cursor, CX = elements left, Z1 = inv in every lane.
- TEXT ·softmaxScaleAVX512Asm(SB), NOSPLIT, $0-20
- MOVQ x+0(FP), DI
- MOVQ n+8(FP), CX
- VBROADCASTSS inv+16(FP), Z1
- loop_scale:
- CMPQ CX, $16
- JL scale_tail
- VMULPS (DI), Z1, Z0
- VMOVUPS Z0, (DI)
- ADDQ $64, DI
- SUBQ $16, CX
- JMP loop_scale
- scale_tail:
- // 0..15 elements left, or n was negative. JLE keeps a negative count out
- // of the DECQ/JNZ loop, which would otherwise write past the end of x.
- TESTQ CX, CX
- JLE scale_done
- scale_tail_loop:
- // VEX-encoded scalar ops: legacy MOVSS here would mix SSE code with dirty
- // AVX-512 upper state and incur transition penalties.
- VMOVSS (DI), X0
- VMULSS X1, X0, X0
- VMOVSS X0, (DI)
- ADDQ $4, DI
- DECQ CX
- JNZ scale_tail_loop
- scale_done:
- // Clear dirty upper YMM/ZMM state before returning to SSE-encoded Go code.
- VZEROUPPER
- RET
|