| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- //go:build amd64
- // +build amd64
- #include "textflag.h"
// func softmaxMaxAVX2Asm(x *float32, n int) float32
//
// Returns the largest of the n float32 values starting at x. When n <= 0
// nothing is read and the result is -Inf (the initial value of the running
// maximum).
//
// Registers:
//   DI    = read cursor
//   CX    = elements still to process
//   Y0/X0 = running maximum (per lane, then reduced to lane 0)
//   Y1/X1 = scratch
TEXT ·softmaxMaxAVX2Asm(SB), NOSPLIT, $0-20
	MOVQ x+0(FP), DI
	MOVQ n+8(FP), CX
	VBROADCASTSS ·negInfConst(SB), Y0 // every lane starts at -Inf

loop_max:
	// Eight elements per iteration while at least eight remain.
	CMPQ CX, $8
	JL   max_reduce
	VMOVUPS (DI), Y1
	VMAXPS  Y1, Y0, Y0
	ADDQ $32, DI
	SUBQ $8, CX
	JMP  loop_max

max_reduce:
	// Horizontal max of the eight lanes of Y0.
	// The high half must be extracted BEFORE anything writes X0: a
	// VEX-encoded write to an XMM register zeroes bits 255:128 of the
	// matching YMM register, which would discard lanes 4-7.
	VEXTRACTF128 $1, Y0, X1
	VMAXPS X1, X0, X0        // X0 (low half of Y0) = max(low, high)
	VPERMILPS $0x4E, X0, X1  // swap the two 64-bit halves
	VMAXPS X1, X0, X0
	VPERMILPS $0xB1, X0, X1  // swap adjacent 32-bit lanes
	VMAXPS X1, X0, X0        // every lane now holds the vector maximum

	// Scalar tail for the remaining 0..7 elements. JLE (rather than JE)
	// also skips the loop for a negative n instead of running off the
	// end of the buffer.
	TESTQ CX, CX
	JLE  max_done

max_tail:
	VMOVSS (DI), X1
	VMAXSS X1, X0, X0
	ADDQ $4, DI
	DECQ CX
	JNZ  max_tail

max_done:
	// Clear the upper YMM halves so the caller's legacy-SSE code does not
	// pay AVX/SSE transition penalties; lane 0 of X0 is unaffected.
	VZEROUPPER
	VMOVSS X0, ret+16(FP)
	RET
// func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
//
// Overwrites x[i] in place with an approximation of exp(x[i] - max), eight
// elements at a time, and returns the sum of the values written.
//
// Only whole groups of eight are processed: the trailing n%8 elements are
// neither rewritten nor included in the returned sum.
// NOTE(review): presumably the Go caller finishes that remainder - confirm.
//
// The exp*, log2EF, halfConst, oneConst and polyP* symbols are defined outside
// this function; the roles given for them below are inferred from how they are
// used here, not from their values.
//
// Registers:
//   DI  = read/write cursor          CX  = elements still to process
//   Y5  = max (broadcast)            Y7  = per-lane running sum
//   Y8  = oneConst                   Y9  = expC2
//   Y10 = expC1                      Y11 = halfConst
//   Y12 = log2EF                     Y13 = expLo (lower clamp)
//   Y14 = expHi (upper clamp)        Y15 = expBiasConst (integer lanes)
//   Y0-Y4, Y6 = scratch
TEXT ·softmaxExpSumAVX2Asm(SB), NOSPLIT, $0-28
	MOVQ x+0(FP), DI
	MOVQ n+8(FP), CX
	VBROADCASTSS max+16(FP), Y5
	VBROADCASTSS ·expHi(SB), Y14
	VBROADCASTSS ·expLo(SB), Y13
	VBROADCASTSS ·log2EF(SB), Y12
	VBROADCASTSS ·halfConst(SB), Y11
	VBROADCASTSS ·expC1(SB), Y10
	VBROADCASTSS ·expC2(SB), Y9
	VBROADCASTSS ·oneConst(SB), Y8
	VPBROADCASTD ·expBiasConst(SB), Y15
	VXORPS Y7, Y7, Y7                  // sum = 0

loop_exp:
	CMPQ CX, $8
	JL   exp_reduce

	// v = clamp(x - max, expLo, expHi)
	VMOVUPS (DI), Y1
	VSUBPS  Y5, Y1, Y1
	VMINPS  Y14, Y1, Y1
	VMAXPS  Y13, Y1, Y1

	// k = floor(v*log2EF + halfConst), kept as int32 (Y6) and float (Y3)
	VMULPS   Y12, Y1, Y2
	VADDPS   Y11, Y2, Y2
	VROUNDPS $1, Y2, Y2                // imm 1 = round toward -Inf
	VCVTPS2DQ Y2, Y6
	VCVTDQ2PS Y6, Y3

	// r = v - k*expC1 - k*expC2 (reduction done in two steps)
	VMULPS Y10, Y3, Y4
	VSUBPS Y4, Y1, Y1
	VMULPS Y9, Y3, Y4
	VSUBPS Y4, Y1, Y1

	VMULPS Y1, Y1, Y4                  // z = r*r

	// Horner evaluation: p = ((((P0*r + P1)*r + P2)*r + P3)*r + P4)*r + P5.
	// The coefficients are re-broadcast every iteration because all
	// sixteen YMM registers are already in use.
	VBROADCASTSS ·polyP0(SB), Y0
	VMULPS Y1, Y0, Y0
	VBROADCASTSS ·polyP1(SB), Y3
	VADDPS Y3, Y0, Y0
	VMULPS Y1, Y0, Y0
	VBROADCASTSS ·polyP2(SB), Y3
	VADDPS Y3, Y0, Y0
	VMULPS Y1, Y0, Y0
	VBROADCASTSS ·polyP3(SB), Y3
	VADDPS Y3, Y0, Y0
	VMULPS Y1, Y0, Y0
	VBROADCASTSS ·polyP4(SB), Y3
	VADDPS Y3, Y0, Y0
	VMULPS Y1, Y0, Y0
	VBROADCASTSS ·polyP5(SB), Y3
	VADDPS Y3, Y0, Y0

	// y = p*z + r + oneConst
	VMULPS Y4, Y0, Y0
	VADDPS Y1, Y0, Y0
	VADDPS Y8, Y0, Y0

	// Scale by the float whose bit pattern is (k + expBiasConst) << 23,
	// i.e. the biased k placed in the float32 exponent field.
	VPADDD Y15, Y6, Y6
	VPSLLD $23, Y6, Y6
	VMULPS Y6, Y0, Y0                  // exp(x - max)

	VADDPS  Y0, Y7, Y7                 // accumulate
	VMOVUPS Y0, (DI)                   // write the result back over x
	ADDQ $32, DI
	SUBQ $8, CX
	JMP  loop_exp

exp_reduce:
	// Horizontal sum of the eight lanes of Y7 into lane 0 of X0.
	VEXTRACTF128 $1, Y7, X1
	VADDPS  X1, X7, X0                 // low half + high half
	VHADDPS X0, X0, X0
	VHADDPS X0, X0, X0

	// Clear the upper YMM halves so the caller's legacy-SSE code does not
	// pay AVX/SSE transition penalties; lane 0 of X0 is unaffected.
	VZEROUPPER
	VMOVSS X0, ret+24(FP)
	RET
// func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
//
// Multiplies each of the n float32 values starting at x by inv, in place.
// Does nothing when n <= 0.
//
// Registers:
//   DI    = read/write cursor
//   CX    = elements still to process
//   Y1/X1 = inv in every lane
//   Y0/X0 = scratch
TEXT ·softmaxScaleAVX2Asm(SB), NOSPLIT, $0-20
	MOVQ x+0(FP), DI
	MOVQ n+8(FP), CX
	VBROADCASTSS inv+16(FP), Y1

loop_scale:
	// Eight elements per iteration while at least eight remain.
	CMPQ CX, $8
	JL   scale_tail
	VMULPS  (DI), Y1, Y0
	VMOVUPS Y0, (DI)
	ADDQ $32, DI
	SUBQ $8, CX
	JMP  loop_scale

scale_tail:
	// Scalar tail for the remaining 0..7 elements. JLE (rather than JE)
	// also skips the loop for a negative n instead of running off the
	// end of the buffer.
	TESTQ CX, CX
	JLE  scale_done

scale_tail_loop:
	VMULSS (DI), X1, X0
	VMOVSS X0, (DI)
	ADDQ $4, DI
	DECQ CX
	JNZ  scale_tail_loop

scale_done:
	// Clear the upper YMM halves so the caller's legacy-SSE code does not
	// pay AVX/SSE transition penalties.
	VZEROUPPER
	RET
// -Inf as an IEEE-754 single (sign bit set, exponent all ones, zero mantissa).
// softmaxMaxAVX2Asm broadcasts it as the starting value of its running maximum.
DATA ·negInfConst(SB)/4, $0xFF800000
GLOBL ·negInfConst(SB), RODATA, $4
|