softmax.go

package nn

import (
	"errors"
	"math"

	"makarna/pkg/backend/cpu"
)
// Softmax applies softmax normalization in-place.
// Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
func Softmax(x *cpu.Tensor) error {
	data := x.DataFloat32()
	if len(data) == 0 {
		return nil
	}
	softmaxInplace(data)

	// Validate the result: every probability must be non-negative and finite.
	var sum float32
	for _, v := range data {
		if v < 0 || math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
			return errors.New("softmax produced invalid value")
		}
		sum += v
	}
	if sum <= 0 || math.IsNaN(float64(sum)) || math.IsInf(float64(sum), 0) {
		return errors.New("softmax produced invalid sum")
	}

	// Renormalize only if the sum has drifted from 1 by more than the tolerance.
	d := float32(math.Abs(float64(sum - 1)))
	if d > 1e-3 {
		inv := 1 / sum
		for i := range data {
			data[i] *= inv
		}
	}
	return nil
}
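
// Illustrative call site (a sketch, not the project's actual API: how a
// *cpu.Tensor is constructed depends on the backend package, and
// cpu.NewTensorFloat32 below is a hypothetical helper used only for
// illustration):
//
//	logits := cpu.NewTensorFloat32([]float32{1, 2, 3})
//	if err := nn.Softmax(logits); err != nil {
//		log.Fatal(err)
//	}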
// softmaxInplace picks the widest SIMD kernel the CPU supports; the length
// thresholds match the float32 lane counts of AVX-512 (16) and AVX2 (8).
func softmaxInplace(data []float32) {
	if len(data) == 0 {
		return
	}
	switch {
	case hasSoftmaxAVX512 && cpu.SupportsAVX512() && len(data) >= 16:
		softmaxAVX512(data)
	case hasSoftmaxAVX2 && cpu.SupportsAVX2() && len(data) >= 8:
		softmaxAVX2(data)
	default:
		softmaxScalar(data)
	}
}
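
// The hasSoftmaxAVX512/hasSoftmaxAVX2 flags and the softmaxAVX512/softmaxAVX2
// kernels are declared elsewhere in the package. A minimal sketch of the
// fallback declarations a non-amd64 build would need (file name, build tag,
// and exact layout are assumptions, not the project's actual code):
//
//	//go:build !amd64
//
//	package nn
//
//	const (
//		hasSoftmaxAVX512 = false
//		hasSoftmaxAVX2   = false
//	)
//
//	func softmaxAVX512(data []float32) { softmaxScalar(data) }
//	func softmaxAVX2(data []float32)   { softmaxScalar(data) }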
// softmaxScalar is the portable fallback: it subtracts the maximum before
// exponentiating so large logits do not overflow.
func softmaxScalar(data []float32) {
	maxVal := float32(-math.MaxFloat32)
	for _, v := range data {
		if v > maxVal {
			maxVal = v
		}
	}
	var sum float32
	for i, v := range data {
		ev := float32(math.Exp(float64(v - maxVal)))
		data[i] = ev
		sum += ev
	}
	inv := 1 / sum
	for i := range data {
		data[i] *= inv
	}
}
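
// Worked example for softmaxScalar: the input {1, 2, 3} is shifted by its
// maximum (3), exponentiated, and normalized, giving roughly
// {0.0900, 0.2447, 0.6652}, which sums to 1.
//
//	v := []float32{1, 2, 3}
//	softmaxScalar(v)
//	// v ≈ [0.0900 0.2447 0.6652]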