1
0

silu.go 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package nn
import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)
  6. // SiLU applies x * sigmoid(x) in-place using the fastest available kernel.
  7. func SiLU(x *cpu.Tensor) error {
  8. siluInplace(x.DataFloat32())
  9. return nil
  10. }
  11. // SwiGLU: out = SiLU(gate) * up. Does not mutate gate.
  12. func SwiGLU(gate, up, out *cpu.Tensor) error {
  13. gData := gate.DataFloat32()
  14. uData := up.DataFloat32()
  15. oData := out.DataFloat32()
  16. if len(oData) == 0 {
  17. return nil
  18. }
  19. if &gData[0] != &oData[0] {
  20. copy(oData, gData)
  21. }
  22. siluInplace(oData)
  23. for i := range oData {
  24. oData[i] *= uData[i]
  25. }
  26. return nil
  27. }
  28. // siluInplace selects the SIMD kernel when available, falling back to scalar.
  29. func siluInplace(data []float32) {
  30. if len(data) == 0 {
  31. return
  32. }
  33. switch {
  34. case hasSiLUAVX512 && cpu.SupportsAVX512():
  35. main := len(data) &^ 15
  36. if main > 0 {
  37. siluAVX512Asm(&data[0], main)
  38. }
  39. if main == len(data) {
  40. return
  41. }
  42. data = data[main:]
  43. case hasSiLUAVX2 && cpu.SupportsAVX2():
  44. main := len(data) &^ 7
  45. if main > 0 {
  46. siluAVX2Asm(&data[0], main)
  47. }
  48. if main == len(data) {
  49. return
  50. }
  51. data = data[main:]
  52. }
  53. siluScalar(data)
  54. }
  55. func siluScalar(data []float32) {
  56. for i := range data {
  57. v := data[i]
  58. data[i] = v / (1.0 + float32(math.Exp(float64(-v))))
  59. }
  60. }