//go:build amd64
// +build amd64

// Go (Plan 9 syntax) assembly, amd64, AVX/AVX2.
// Operand order is source(s) first, destination last.
// All routines are leaf functions with no local stack frame (NOSPLIT, $0).
// Every routine executes VZEROUPPER before returning so that later legacy-SSE
// code does not pay for dirty upper YMM halves.

#include "textflag.h"

// func softmaxMaxAVX2Asm(x *float32, n int) float32
//
// Returns the maximum of the n float32 values starting at x.
// For n <= 0 nothing is read and the -Inf seed is returned.
//
// Registers:
//   DI    = cursor into x
//   CX    = elements remaining
//   Y0/X0 = running maximum (8 lanes, then 4, then scalar in lane 0)
//   Y1/X1 = scratch
TEXT ·softmaxMaxAVX2Asm(SB), NOSPLIT, $0-20
	MOVQ         x+0(FP), DI
	MOVQ         n+8(FP), CX
	VBROADCASTSS ·negInfConst(SB), Y0      // seed every lane with -Inf

loop_max:
	CMPQ    CX, $8
	JL      max_reduce                     // fewer than 8 left (signed: n is a Go int)
	VMOVUPS (DI), Y1
	VMAXPS  Y1, Y0, Y0
	ADDQ    $32, DI
	SUBQ    $8, CX
	JMP     loop_max

max_reduce:
	// Fold 8 lanes down to 1. The high half MUST be extracted before anything
	// writes X0: a VEX-encoded write to X0 zeroes bits 255:128 of Y0, which
	// would silently replace the upper-lane maxima with 0.
	VEXTRACTF128 $1, Y0, X1                // X1 = lanes 4..7
	VMAXPS       X1, X0, X0                // X0 = max(lanes 0..3, lanes 4..7)
	VPERMILPS    $0x4E, X0, X1             // swap 64-bit halves
	VMAXPS       X1, X0, X0
	VPERMILPS    $0xB1, X0, X1             // swap adjacent lanes
	VMAXPS       X1, X0, X0                // every lane now holds the vector max

	// Scalar tail: CX = n % 8 when n >= 0. JLE (not JE) so that a negative n
	// cannot enter the count-down loop and wrap around.
	TESTQ CX, CX
	JLE   max_done

max_tail:
	VMOVSS (DI), X1
	VMAXSS X1, X0, X0
	ADDQ   $4, DI
	DECQ   CX
	JNZ    max_tail

max_done:
	VMOVSS X0, ret+16(FP)
	VZEROUPPER
	RET

// func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
//
// For each complete group of 8 elements, overwrites x[i] in place with an
// approximation of exp(x[i] - max) and returns the sum of the values written.
//
// NOTE: there is no scalar tail here. The trailing n % 8 elements are neither
// modified nor included in the returned sum; the caller has to deal with them.
//
// The exp evaluation is range reduction plus a degree-5 polynomial. The
// constants (expHi, expLo, log2EF, halfConst, expC1, expC2, oneConst,
// expBiasConst, polyP0..polyP5) are defined outside this file; the roles noted
// below are inferred from how they are used, not from their values.
//
// Registers:
//   DI       = cursor into x
//   CX       = elements remaining
//   Y5       = max broadcast to all lanes
//   Y7       = running sum (8 partial sums)
//   Y8..Y15  = loop-invariant constants
//   Y0..Y4, Y6 = scratch
TEXT ·softmaxExpSumAVX2Asm(SB), NOSPLIT, $0-28
	MOVQ         x+0(FP), DI
	MOVQ         n+8(FP), CX
	VMOVSS       max+16(FP), X5
	VBROADCASTSS X5, Y5                    // max

	VBROADCASTSS ·expHi(SB), Y14           // upper clamp for the argument
	VBROADCASTSS ·expLo(SB), Y13           // lower clamp for the argument
	VBROADCASTSS ·log2EF(SB), Y12
	VBROADCASTSS ·halfConst(SB), Y11
	VBROADCASTSS ·expC1(SB), Y10
	VBROADCASTSS ·expC2(SB), Y9
	VBROADCASTSS ·oneConst(SB), Y8
	VPBROADCASTD ·expBiasConst(SB), Y15    // integer added to k before shifting into the exponent field

	VXORPS Y7, Y7, Y7                      // sum accumulator

loop_exp:
	CMPQ CX, $8
	JL   exp_reduce

	VMOVUPS (DI), Y1
	VSUBPS  Y5, Y1, Y1                     // x - max
	VMINPS  Y14, Y1, Y1                    // clamp to [expLo, expHi]
	VMAXPS  Y13, Y1, Y1

	// k = floor(x*log2EF + halfConst), kept as int32 (Y6) and float (Y3)
	VMULPS    Y12, Y1, Y2
	VADDPS    Y11, Y2, Y2
	VROUNDPS  $1, Y2, Y2                   // rounding mode 1 = toward -Inf
	VCVTPS2DQ Y2, Y6
	VCVTDQ2PS Y6, Y3

	// r = x - k*expC1 - k*expC2 (two-step reduction)
	VMULPS Y10, Y3, Y4
	VSUBPS Y4, Y1, Y1
	VMULPS Y9, Y3, Y4
	VSUBPS Y4, Y1, Y1

	VMULPS Y1, Y1, Y4                      // z = x*x

	// Horner evaluation: y = ((((P0*x + P1)*x + P2)*x + P3)*x + P4)*x + P5.
	// The coefficients are re-broadcast every iteration because all sixteen
	// YMM registers are already in use.
	VBROADCASTSS ·polyP0(SB), Y0
	VMULPS       Y1, Y0, Y0
	VBROADCASTSS ·polyP1(SB), Y3
	VADDPS       Y3, Y0, Y0
	VMULPS       Y1, Y0, Y0
	VBROADCASTSS ·polyP2(SB), Y3
	VADDPS       Y3, Y0, Y0
	VMULPS       Y1, Y0, Y0
	VBROADCASTSS ·polyP3(SB), Y3
	VADDPS       Y3, Y0, Y0
	VMULPS       Y1, Y0, Y0
	VBROADCASTSS ·polyP4(SB), Y3
	VADDPS       Y3, Y0, Y0
	VMULPS       Y1, Y0, Y0
	VBROADCASTSS ·polyP5(SB), Y3
	VADDPS       Y3, Y0, Y0
	VMULPS       Y4, Y0, Y0                // y *= z
	VADDPS       Y1, Y0, Y0                // y += x
	VADDPS       Y8, Y0, Y0                // y += 1

	// Scale by 2^k: build the float whose exponent field is k + bias.
	VPADDD Y15, Y6, Y6
	VPSLLD $23, Y6, Y6
	VMULPS Y6, Y0, Y0                      // exp(x - max)

	VADDPS  Y0, Y7, Y7
	VMOVUPS Y0, (DI)                       // write the result back in place
	ADDQ    $32, DI
	SUBQ    $8, CX
	JMP     loop_exp

exp_reduce:
	// Horizontal add of the 8 partial sums. Source (Y7) and destinations
	// (X0, X1) are different registers, so extracting the low half first is safe.
	VEXTRACTF128 $0, Y7, X0
	VEXTRACTF128 $1, Y7, X1
	VADDPS       X1, X0, X0
	VHADDPS      X0, X0, X0
	VHADDPS      X0, X0, X0
	VMOVSS       X0, ret+24(FP)
	VZEROUPPER
	RET

// func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
//
// Multiplies the n float32 values starting at x by inv, in place.
// For n <= 0 nothing is read or written.
//
// Registers:
//   DI    = cursor into x
//   CX    = elements remaining
//   Y1/X1 = inv (all lanes; lane 0 is reused for the scalar tail)
//   Y0/X0 = scratch
TEXT ·softmaxScaleAVX2Asm(SB), NOSPLIT, $0-20
	MOVQ         x+0(FP), DI
	MOVQ         n+8(FP), CX
	VMOVSS       inv+16(FP), X1
	VBROADCASTSS X1, Y1

loop_scale:
	CMPQ    CX, $8
	JL      scale_tail
	VMULPS  (DI), Y1, Y0
	VMOVUPS Y0, (DI)
	ADDQ    $32, DI
	SUBQ    $8, CX
	JMP     loop_scale

scale_tail:
	// JLE (not JE) so that a negative n cannot enter the count-down loop.
	TESTQ CX, CX
	JLE   scale_done

scale_tail_loop:
	VMOVSS (DI), X0
	VMULSS X1, X0, X0
	VMOVSS X0, (DI)
	ADDQ   $4, DI
	DECQ   CX
	JNZ    scale_tail_loop

scale_done:
	VZEROUPPER
	RET

// Additional constants

// float32 -Inf bit pattern; seeds the running maximum in softmaxMaxAVX2Asm.
DATA ·negInfConst+0(SB)/4, $0xff800000
GLOBL ·negInfConst(SB), RODATA, $4