// softmax_amd64.go (1.7 KB)
  1. //go:build amd64
  2. package nn
  3. import "math"
  4. const (
  5. hasSoftmaxAVX2 = true
  6. hasSoftmaxAVX512 = true
  7. )
  8. //go:noescape
  9. func softmaxMaxAVX2Asm(x *float32, n int) float32
  10. //go:noescape
  11. func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
  12. //go:noescape
  13. func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
  14. //go:noescape
  15. func softmaxMaxAVX512Asm(x *float32, n int) float32
  16. //go:noescape
  17. func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
  18. //go:noescape
  19. func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
  20. func softmaxAVX2(data []float32) {
  21. n := len(data)
  22. main := n &^ 7 // process multiples of 8
  23. maxVal := float32(-math.MaxFloat32)
  24. if main > 0 {
  25. maxVal = softmaxMaxAVX2Asm(&data[0], main)
  26. }
  27. for i := main; i < n; i++ {
  28. if data[i] > maxVal {
  29. maxVal = data[i]
  30. }
  31. }
  32. var sum float32
  33. if main > 0 {
  34. sum += softmaxExpSumAVX2Asm(&data[0], main, maxVal)
  35. }
  36. for i := main; i < n; i++ {
  37. ev := float32(math.Exp(float64(data[i] - maxVal)))
  38. data[i] = ev
  39. sum += ev
  40. }
  41. inv := 1.0 / sum
  42. if main > 0 {
  43. softmaxScaleAVX2Asm(&data[0], main, inv)
  44. }
  45. for i := main; i < n; i++ {
  46. data[i] *= inv
  47. }
  48. }
  49. func softmaxAVX512(data []float32) {
  50. n := len(data)
  51. main := n &^ 15 // process multiples of 16
  52. maxVal := float32(-math.MaxFloat32)
  53. if main > 0 {
  54. maxVal = softmaxMaxAVX512Asm(&data[0], main)
  55. }
  56. for i := main; i < n; i++ {
  57. if data[i] > maxVal {
  58. maxVal = data[i]
  59. }
  60. }
  61. var sum float32
  62. if main > 0 {
  63. sum += softmaxExpSumAVX512Asm(&data[0], main, maxVal)
  64. }
  65. for i := main; i < n; i++ {
  66. ev := float32(math.Exp(float64(data[i] - maxVal)))
  67. data[i] = ev
  68. sum += ev
  69. }
  70. inv := 1.0 / sum
  71. if main > 0 {
  72. softmaxScaleAVX512Asm(&data[0], main, inv)
  73. }
  74. for i := main; i < n; i++ {
  75. data[i] *= inv
  76. }
  77. }