| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- package nn
- import (
- "errors"
- "math"
- "makarna/pkg/backend/cpu"
- )
- // Softmax applies softmax normalization in-place.
- // Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
- func Softmax(x *cpu.Tensor) error {
- data := x.DataFloat32()
- if len(data) == 0 {
- return nil
- }
- softmaxInplace(data)
- var sum float32
- for _, v := range data {
- if v < 0 || math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
- return errors.New("softmax produced invalid value")
- }
- sum += v
- }
- if sum <= 0 || math.IsNaN(float64(sum)) || math.IsInf(float64(sum), 0) {
- return errors.New("softmax produced invalid sum")
- }
- d := float32(math.Abs(float64(sum - 1)))
- if d > 1e-3 {
- inv := 1 / sum
- for i := range data {
- data[i] *= inv
- }
- }
- return nil
- }
- func softmaxInplace(data []float32) {
- if len(data) == 0 {
- return
- }
- switch {
- case hasSoftmaxAVX512 && cpu.SupportsAVX512() && len(data) >= 16:
- softmaxAVX512(data)
- case hasSoftmaxAVX2 && cpu.SupportsAVX2() && len(data) >= 8:
- softmaxAVX2(data)
- default:
- softmaxScalar(data)
- }
- }
// softmaxScalar is the portable softmax kernel; it overwrites data in place.
//
// It works in three passes: find the largest element, replace every element
// with exp(element - largest) while summing those exponentials in float32,
// then multiply every element by the reciprocal of that sum. Subtracting the
// largest element first keeps the exponent at or below zero for ordinary
// finite inputs, which avoids overflow in math.Exp.
func softmaxScalar(data []float32) {
	// Pass 1: running maximum, seeded with the most negative finite float32.
	peak := float32(-math.MaxFloat32)
	for i := range data {
		if data[i] > peak {
			peak = data[i]
		}
	}

	// Pass 2: shifted exponentials and their total.
	total := float32(0)
	for i := range data {
		e := float32(math.Exp(float64(data[i] - peak)))
		total += e
		data[i] = e
	}

	// Pass 3: normalise with one reciprocal instead of a division per element.
	scale := 1 / total
	for i := range data {
		data[i] *= scale
	}
}
|