rmsnorm.go 720 B

1234567891011121314151617181920212223242526272829303132
  1. // Package nn provides neural network layer operations
  2. package nn
  import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)
  7. // RMSNorm normalizes input in-place using RMS normalization
  8. // x: [batch, dim], w: [dim]
  9. func RMSNorm(x, w *cpu.Tensor, eps float32) error {
  10. xData := x.DataFloat32()
  11. wData := w.DataFloat32()
  12. dim := w.Shape().NumElements()
  13. numRows := x.Shape().NumElements() / dim
  14. for i := 0; i < numRows; i++ {
  15. row := xData[i*dim : (i+1)*dim]
  16. // Sum of squares (uses SIMD dot when available)
  17. ss := cpu.DotFloat32(row, row) / float32(dim)
  18. // Normalize and scale
  19. invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
  20. for j := 0; j < dim; j++ {
  21. row[j] = row[j] * invRMS * wData[j]
  22. }
  23. }
  24. return nil
  25. }