//go:build amd64
// +build amd64

#include "textflag.h"

// func softmaxMaxAVX512Asm(x *float32, n int) float32
//
// Returns the maximum of the n float32 values starting at x, folded together
// with the value broadcast from ·negInfConst (which is therefore the result
// when n <= 0).
//
// Registers:
//   DI = read cursor
//   CX = elements remaining
//   Z0 = running per-lane maximum (lane 0 holds the scalar result at the end)
//   Z1 = scratch
TEXT ·softmaxMaxAVX512Asm(SB), NOSPLIT, $0-20
	MOVQ	x+0(FP), DI
	MOVQ	n+8(FP), CX
	VBROADCASTSS	·negInfConst(SB), Z0

loop_max:
	// Consume 16 floats (64 bytes) per iteration while at least 16 remain.
	CMPQ	CX, $16
	JL	max_reduce
	VMOVUPS	(DI), Z1
	VMAXPS	Z1, Z0, Z0
	ADDQ	$64, DI
	SUBQ	$16, CX
	JMP	loop_max

max_reduce:
	// Horizontal max: 512 -> 256 -> 128 bits, then within the 128-bit lane.
	VEXTRACTF32X8	$1, Z0, Y1
	VMAXPS	Y1, Y0, Y0
	VEXTRACTF128	$1, Y0, X1
	VMAXPS	X1, X0, X0
	VPERMILPS	$0x4E, X0, X1	// swap the two 64-bit halves
	VMAXPS	X1, X0, X0
	VPERMILPS	$0xB1, X0, X1	// swap adjacent 32-bit elements
	VMAXPS	X1, X0, X0

	// Scalar tail: fold the remaining (< 16) elements into lane 0.
	// JLE rather than JE so that a negative n cannot enter the loop.
	TESTQ	CX, CX
	JLE	max_done

max_tail:
	VMOVSS	(DI), X1
	VMAXSS	X1, X0, X0
	ADDQ	$4, DI
	DECQ	CX
	JNZ	max_tail

max_done:
	// VEX-encoded store plus VZEROUPPER: leave no dirty upper vector state
	// behind for the legacy-SSE code that runs after we return.
	VMOVSS	X0, ret+16(FP)
	VZEROUPPER
	RET

// func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
//
// For every complete group of 16 floats at x, overwrites the group in place
// with a polynomial approximation of exp(x[i] - max) and returns the sum of
// all values written.
//
// Only whole groups of 16 are processed: the trailing n % 16 elements are
// neither rewritten nor included in the returned sum.
// NOTE(review): presumably the Go caller handles that remainder - confirm.
//
// The constants (·expHi, ·expLo, ·log2EF, ·halfConst, ·expC1, ·expC2,
// ·oneConst, ·expBiasConst, ·polyP0..·polyP5) are defined outside this file.
// NOTE(review): the instruction sequence has the shape of the Cephes expf
// range reduction + degree-5 polynomial; confirm against the constant values.
//
// Registers:
//   DI       = read/write cursor
//   CX       = elements remaining
//   Z5       = max (broadcast)
//   Z7       = running per-lane sum
//   Z8..Z15  = broadcast scalar constants (see loads below)
//   Z16..Z21 = broadcast polynomial coefficients polyP0..polyP5
//   Z0..Z4, Z6 = scratch
TEXT ·softmaxExpSumAVX512Asm(SB), NOSPLIT, $0-28
	MOVQ	x+0(FP), DI
	MOVQ	n+8(FP), CX
	VMOVSS	max+16(FP), X5
	VBROADCASTSS	X5, Z5

	VBROADCASTSS	·expHi(SB), Z14
	VBROADCASTSS	·expLo(SB), Z13
	VBROADCASTSS	·log2EF(SB), Z12
	VBROADCASTSS	·halfConst(SB), Z11
	VBROADCASTSS	·expC1(SB), Z10
	VBROADCASTSS	·expC2(SB), Z9
	VBROADCASTSS	·oneConst(SB), Z8
	VPBROADCASTD	·expBiasConst(SB), Z15

	// Polynomial coefficients are loop-invariant; load them once.
	VBROADCASTSS	·polyP0(SB), Z16
	VBROADCASTSS	·polyP1(SB), Z17
	VBROADCASTSS	·polyP2(SB), Z18
	VBROADCASTSS	·polyP3(SB), Z19
	VBROADCASTSS	·polyP4(SB), Z20
	VBROADCASTSS	·polyP5(SB), Z21

	VXORPS	Z7, Z7, Z7	// sum accumulator = 0

loop_exp:
	CMPQ	CX, $16
	JL	exp_reduce

	VMOVUPS	(DI), Z1
	VSUBPS	Z5, Z1, Z1	// v = x - max
	VMINPS	Z14, Z1, Z1	// v = min(v, expHi)
	VMAXPS	Z13, Z1, Z1	// v = max(v, expLo)

	// fx = floor(v * log2EF + halfConst), kept both as int32 and float32.
	VMULPS	Z12, Z1, Z2
	VADDPS	Z11, Z2, Z2
	VRNDSCALEPS	$1, Z2, Z2	// imm 1 = round toward -infinity
	VCVTPS2DQ	Z2, Z6	// Z6 = int32(fx)
	VCVTDQ2PS	Z6, Z3	// Z3 = float32(fx)

	// r = v - fx*expC1 - fx*expC2
	VMULPS	Z10, Z3, Z4
	VSUBPS	Z4, Z1, Z1
	VMULPS	Z9, Z3, Z4
	VSUBPS	Z4, Z1, Z1

	VMULPS	Z1, Z1, Z4	// z = r*r

	// Horner evaluation: y = ((((P0*r + P1)*r + P2)*r + P3)*r + P4)*r + P5
	VMULPS	Z1, Z16, Z0
	VADDPS	Z17, Z0, Z0
	VMULPS	Z1, Z0, Z0
	VADDPS	Z18, Z0, Z0
	VMULPS	Z1, Z0, Z0
	VADDPS	Z19, Z0, Z0
	VMULPS	Z1, Z0, Z0
	VADDPS	Z20, Z0, Z0
	VMULPS	Z1, Z0, Z0
	VADDPS	Z21, Z0, Z0

	VMULPS	Z4, Z0, Z0	// y *= z
	VADDPS	Z1, Z0, Z0	// y += r
	VADDPS	Z8, Z0, Z0	// y += oneConst

	// Build the power-of-two scale by placing (fx + expBiasConst) in the
	// float32 exponent field (bits 23 and up), then apply it.
	VPADDD	Z15, Z6, Z6
	VPSLLD	$23, Z6, Z6
	VMULPS	Z6, Z0, Z0	// y = approx exp(x - max)

	VADDPS	Z0, Z7, Z7	// accumulate
	VMOVUPS	Z0, (DI)	// write result back in place

	ADDQ	$64, DI
	SUBQ	$16, CX
	JMP	loop_exp

exp_reduce:
	// Horizontal sum of the accumulator Z7. The low half must come from Y7:
	// Y0 only holds the most recent exp vector (or nothing meaningful when
	// the loop never ran), so adding into Y0 would discard half of the sum.
	VEXTRACTF32X8	$1, Z7, Y1
	VADDPS	Y1, Y7, Y0
	VEXTRACTF128	$1, Y0, X1
	VADDPS	X1, X0, X0
	VHADDPS	X0, X0, X0
	VHADDPS	X0, X0, X0

	VMOVSS	X0, ret+24(FP)
	VZEROUPPER
	RET

// func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
//
// Multiplies each of the n float32 values starting at x by inv, in place.
// Groups of 16 are handled with 512-bit operations; the remaining (< 16)
// elements are handled one at a time. n <= 0 leaves memory untouched.
//
// Registers:
//   DI = read/write cursor
//   CX = elements remaining
//   Z1 = inv (broadcast; lane 0 is reused by the scalar tail)
//   Z0 = scratch
TEXT ·softmaxScaleAVX512Asm(SB), NOSPLIT, $0-20
	MOVQ	x+0(FP), DI
	MOVQ	n+8(FP), CX
	VMOVSS	inv+16(FP), X1
	VBROADCASTSS	X1, Z1

loop_scale:
	CMPQ	CX, $16
	JL	scale_tail
	VMULPS	(DI), Z1, Z0
	VMOVUPS	Z0, (DI)
	ADDQ	$64, DI
	SUBQ	$16, CX
	JMP	loop_scale

scale_tail:
	// JLE rather than JE so that a negative n cannot enter the loop.
	TESTQ	CX, CX
	JLE	scale_done

scale_tail_loop:
	// VEX-encoded scalar moves avoid mixing legacy SSE with dirty ZMM state.
	VMOVSS	(DI), X0
	VMULSS	X1, X0, X0
	VMOVSS	X0, (DI)
	ADDQ	$4, DI
	DECQ	CX
	JNZ	scale_tail_loop

scale_done:
	VZEROUPPER
	RET