//go:build amd64 package nn import "math" const ( hasSoftmaxAVX2 = true hasSoftmaxAVX512 = true ) //go:noescape func softmaxMaxAVX2Asm(x *float32, n int) float32 //go:noescape func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32 //go:noescape func softmaxScaleAVX2Asm(x *float32, n int, inv float32) //go:noescape func softmaxMaxAVX512Asm(x *float32, n int) float32 //go:noescape func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32 //go:noescape func softmaxScaleAVX512Asm(x *float32, n int, inv float32) func softmaxAVX2(data []float32) { n := len(data) main := n &^ 7 // process multiples of 8 maxVal := float32(-math.MaxFloat32) if main > 0 { maxVal = softmaxMaxAVX2Asm(&data[0], main) } for i := main; i < n; i++ { if data[i] > maxVal { maxVal = data[i] } } var sum float32 if main > 0 { sum += softmaxExpSumAVX2Asm(&data[0], main, maxVal) } for i := main; i < n; i++ { ev := float32(math.Exp(float64(data[i] - maxVal))) data[i] = ev sum += ev } inv := 1.0 / sum if main > 0 { softmaxScaleAVX2Asm(&data[0], main, inv) } for i := main; i < n; i++ { data[i] *= inv } } func softmaxAVX512(data []float32) { n := len(data) main := n &^ 15 // process multiples of 16 maxVal := float32(-math.MaxFloat32) if main > 0 { maxVal = softmaxMaxAVX512Asm(&data[0], main) } for i := main; i < n; i++ { if data[i] > maxVal { maxVal = data[i] } } var sum float32 if main > 0 { sum += softmaxExpSumAVX512Asm(&data[0], main, maxVal) } for i := main; i < n; i++ { ev := float32(math.Exp(float64(data[i] - maxVal))) data[i] = ev sum += ev } inv := 1.0 / sum if main > 0 { softmaxScaleAVX512Asm(&data[0], main, inv) } for i := main; i < n; i++ { data[i] *= inv } }