package nn import ( "makarna/pkg/backend/cpu" "makarna/pkg/backend/cpu/ops" ) // TransformerBlock applies a standard pre-norm transformer block: // 1. RMSNorm → Attention → Add residual // 2. RMSNorm → MLP → Add residual // // attnFn: function that computes attention given normalized input // mlpFn: function that computes MLP given normalized input func TransformerBlock( hidden *cpu.Tensor, attnNorm, mlpNorm *cpu.Tensor, eps float32, attnFn func(*cpu.Tensor) *cpu.Tensor, mlpFn func(*cpu.Tensor) *cpu.Tensor, ) *cpu.Tensor { // Save residual residual := ops.Zeros(hidden.Shape()) ops.Copy(residual, hidden) // Attention sub-block RMSNorm(hidden, attnNorm, eps) attnOut := attnFn(hidden) ops.Add(residual, attnOut) ops.Copy(hidden, residual) // MLP sub-block ops.Copy(residual, hidden) RMSNorm(hidden, mlpNorm, eps) mlpOut := mlpFn(hidden) ops.Add(residual, mlpOut) ops.Copy(hidden, residual) return hidden }