1
0

transformer.go 947 B

123456789101112131415161718192021222324252627282930313233343536373839
  1. package nn
  2. import (
  3. "makarna/pkg/backend/cpu"
  4. "makarna/pkg/backend/cpu/ops"
  5. )
  6. // TransformerBlock applies a standard pre-norm transformer block:
  7. // 1. RMSNorm → Attention → Add residual
  8. // 2. RMSNorm → MLP → Add residual
  9. //
  10. // attnFn: function that computes attention given normalized input
  11. // mlpFn: function that computes MLP given normalized input
  12. func TransformerBlock(
  13. hidden *cpu.Tensor,
  14. attnNorm, mlpNorm *cpu.Tensor,
  15. eps float32,
  16. attnFn func(*cpu.Tensor) *cpu.Tensor,
  17. mlpFn func(*cpu.Tensor) *cpu.Tensor,
  18. ) *cpu.Tensor {
  19. // Save residual
  20. residual := ops.Zeros(hidden.Shape())
  21. ops.Copy(residual, hidden)
  22. // Attention sub-block
  23. RMSNorm(hidden, attnNorm, eps)
  24. attnOut := attnFn(hidden)
  25. ops.Add(residual, attnOut)
  26. ops.Copy(hidden, residual)
  27. // MLP sub-block
  28. ops.Copy(residual, hidden)
  29. RMSNorm(hidden, mlpNorm, eps)
  30. mlpOut := mlpFn(hidden)
  31. ops.Add(residual, mlpOut)
  32. ops.Copy(hidden, residual)
  33. return hidden
  34. }