| 1234567891011121314151617181920212223242526272829303132 |
- // Package nn provides neural network layer operations
- package nn
import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)
- // RMSNorm normalizes input in-place using RMS normalization
- // x: [batch, dim], w: [dim]
- func RMSNorm(x, w *cpu.Tensor, eps float32) error {
- xData := x.DataFloat32()
- wData := w.DataFloat32()
- dim := w.Shape().NumElements()
- numRows := x.Shape().NumElements() / dim
- for i := 0; i < numRows; i++ {
- row := xData[i*dim : (i+1)*dim]
- // Sum of squares (uses SIMD dot when available)
- ss := cpu.DotFloat32(row, row) / float32(dim)
- // Normalize and scale
- invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
- for j := 0; j < dim; j++ {
- row[j] = row[j] * invRMS * wData[j]
- }
- }
- return nil
- }
|