package nn

import (
	"errors"
	"math"

	"makarna/pkg/backend/cpu"
)

// Softmax applies softmax normalization in-place.
// Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
// After the kernel runs, the output is validated and renormalized if the
// probabilities drift noticeably from summing to 1.
func Softmax(x *cpu.Tensor) error {
	data := x.DataFloat32()
	if len(data) == 0 {
		return nil
	}
	softmaxInplace(data)

	// Validate the kernel output: every probability must be finite and
	// non-negative, and the total must be a usable positive number.
	var sum float32
	for _, v := range data {
		if v < 0 || math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
			return errors.New("softmax produced invalid value")
		}
		sum += v
	}
	if sum <= 0 || math.IsNaN(float64(sum)) || math.IsInf(float64(sum), 0) {
		return errors.New("softmax produced invalid sum")
	}

	// Renormalize only if the sum has drifted more than the tolerance from 1.
	d := float32(math.Abs(float64(sum - 1)))
	if d > 1e-3 {
		inv := 1 / sum
		for i := range data {
			data[i] *= inv
		}
	}
	return nil
}

// softmaxInplace dispatches to the widest SIMD kernel the CPU supports and
// the input is large enough to fill, falling back to the scalar path.
func softmaxInplace(data []float32) {
	if len(data) == 0 {
		return
	}
	switch {
	case hasSoftmaxAVX512 && cpu.SupportsAVX512() && len(data) >= 16:
		softmaxAVX512(data)
	case hasSoftmaxAVX2 && cpu.SupportsAVX2() && len(data) >= 8:
		softmaxAVX2(data)
	default:
		softmaxScalar(data)
	}
}

// softmaxScalar is the portable fallback: subtract the maximum for numerical
// stability, exponentiate, then normalize by the sum.
func softmaxScalar(data []float32) {
	maxVal := float32(-math.MaxFloat32)
	for _, v := range data {
		if v > maxVal {
			maxVal = v
		}
	}
	var sum float32
	for i, v := range data {
		ev := float32(math.Exp(float64(v - maxVal)))
		data[i] = ev
		sum += ev
	}
	inv := 1.0 / sum
	for i := range data {
		data[i] *= inv
	}
}
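
// The sketch below is a hypothetical usage example, not part of this file's
// API: it assumes a cpu.NewTensor constructor that allocates a float32 tensor
// of the given length and exposes its storage via DataFloat32. Adjust it to
// the package's actual Tensor construction before relying on it.
//
//	logits := []float32{2.0, 1.0, 0.1}
//	t := cpu.NewTensor(len(logits)) // hypothetical constructor
//	copy(t.DataFloat32(), logits)
//	if err := Softmax(t); err != nil {
//		// handle invalid kernel output (NaN/Inf values, non-positive sum)
//	}
//	// t.DataFloat32() now holds probabilities summing to ~1,
//	// approximately [0.659, 0.242, 0.099] for the logits above.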