package nn

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/ops"
)
- // TransformerBlock applies a standard pre-norm transformer block:
- // 1. RMSNorm → Attention → Add residual
- // 2. RMSNorm → MLP → Add residual
- //
- // attnFn: function that computes attention given normalized input
- // mlpFn: function that computes MLP given normalized input
- func TransformerBlock(
- hidden *cpu.Tensor,
- attnNorm, mlpNorm *cpu.Tensor,
- eps float32,
- attnFn func(*cpu.Tensor) *cpu.Tensor,
- mlpFn func(*cpu.Tensor) *cpu.Tensor,
- ) *cpu.Tensor {
- // Save residual
- residual := ops.Zeros(hidden.Shape())
- ops.Copy(residual, hidden)
- // Attention sub-block
- RMSNorm(hidden, attnNorm, eps)
- attnOut := attnFn(hidden)
- ops.Add(residual, attnOut)
- ops.Copy(hidden, residual)
- // MLP sub-block
- ops.Copy(residual, hidden)
- RMSNorm(hidden, mlpNorm, eps)
- mlpOut := mlpFn(hidden)
- ops.Add(residual, mlpOut)
- ops.Copy(hidden, residual)
- return hidden
- }