| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- //go:build amd64
- package nn
- import "math"
// Compile-time flags, both fixed to true in this file, which is restricted to
// amd64 builds by the go:build constraint at the top.
// NOTE(review): presumably dispatch code elsewhere combines these with runtime
// CPU feature detection, and a non-amd64 counterpart defines them as false —
// confirm, since neither is visible here.
- const (
- hasSoftmaxAVX2 = true
- hasSoftmaxAVX512 = true
- )
// softmaxMaxAVX2Asm is declared without a body, so its implementation lives
// outside this Go file. softmaxAVX2 calls it with a pointer to the first
// element and a count that is a non-zero multiple of 8, and uses the result as
// the starting maximum before scanning the remaining tail elements.
// NOTE(review): presumably returns the largest of the n values at x — confirm
// against the assembly.
- //go:noescape
- func softmaxMaxAVX2Asm(x *float32, n int) float32
// softmaxExpSumAVX2Asm is declared without a body, so its implementation lives
// outside this Go file. softmaxAVX2 calls it with a pointer to the first
// element, a count that is a non-zero multiple of 8, and the maximum it found,
// then adds the returned value to its running sum of exponentials.
// NOTE(review): the caller never writes exponentials for these n elements
// itself, so this routine presumably stores exp(x[i]-max) back into x as well
// as returning their sum — confirm against the assembly.
- //go:noescape
- func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
// softmaxScaleAVX2Asm is declared without a body, so its implementation lives
// outside this Go file. softmaxAVX2 calls it with a pointer to the first
// element, a count that is a non-zero multiple of 8, and inv = 1/sum.
// NOTE(review): presumably multiplies each of the n values at x by inv in
// place, mirroring what the caller does for the tail — confirm against the
// assembly.
- //go:noescape
- func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
// softmaxMaxAVX512Asm is declared without a body, so its implementation lives
// outside this Go file. softmaxAVX512 calls it with a pointer to the first
// element and a count that is a non-zero multiple of 16, and uses the result
// as the starting maximum before scanning the remaining tail elements.
// NOTE(review): presumably returns the largest of the n values at x — confirm
// against the assembly.
- //go:noescape
- func softmaxMaxAVX512Asm(x *float32, n int) float32
// softmaxExpSumAVX512Asm is declared without a body, so its implementation
// lives outside this Go file. softmaxAVX512 calls it with a pointer to the
// first element, a count that is a non-zero multiple of 16, and the maximum it
// found, then adds the returned value to its running sum of exponentials.
// NOTE(review): the caller never writes exponentials for these n elements
// itself, so this routine presumably stores exp(x[i]-max) back into x as well
// as returning their sum — confirm against the assembly.
- //go:noescape
- func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
// softmaxScaleAVX512Asm is declared without a body, so its implementation
// lives outside this Go file. softmaxAVX512 calls it with a pointer to the
// first element, a count that is a non-zero multiple of 16, and inv = 1/sum.
// NOTE(review): presumably multiplies each of the n values at x by inv in
// place, mirroring what the caller does for the tail — confirm against the
// assembly.
- //go:noescape
- func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
- func softmaxAVX2(data []float32) {
- n := len(data)
- main := n &^ 7 // process multiples of 8
- maxVal := float32(-math.MaxFloat32)
- if main > 0 {
- maxVal = softmaxMaxAVX2Asm(&data[0], main)
- }
- for i := main; i < n; i++ {
- if data[i] > maxVal {
- maxVal = data[i]
- }
- }
- var sum float32
- if main > 0 {
- sum += softmaxExpSumAVX2Asm(&data[0], main, maxVal)
- }
- for i := main; i < n; i++ {
- ev := float32(math.Exp(float64(data[i] - maxVal)))
- data[i] = ev
- sum += ev
- }
- inv := 1.0 / sum
- if main > 0 {
- softmaxScaleAVX2Asm(&data[0], main, inv)
- }
- for i := main; i < n; i++ {
- data[i] *= inv
- }
- }
- func softmaxAVX512(data []float32) {
- n := len(data)
- main := n &^ 15 // process multiples of 16
- maxVal := float32(-math.MaxFloat32)
- if main > 0 {
- maxVal = softmaxMaxAVX512Asm(&data[0], main)
- }
- for i := main; i < n; i++ {
- if data[i] > maxVal {
- maxVal = data[i]
- }
- }
- var sum float32
- if main > 0 {
- sum += softmaxExpSumAVX512Asm(&data[0], main, maxVal)
- }
- for i := main; i < n; i++ {
- ev := float32(math.Exp(float64(data[i] - maxVal)))
- data[i] = ev
- sum += ev
- }
- inv := 1.0 / sum
- if main > 0 {
- softmaxScaleAVX512Asm(&data[0], main, inv)
- }
- for i := main; i < n; i++ {
- data[i] *= inv
- }
- }
|