1
0

mlp.go 718 B

1234567891011121314151617181920212223242526272829303132
  1. package nn
  2. import (
  3. "makarna/pkg/backend/cpu"
  4. "makarna/pkg/backend/cpu/matmul"
  5. "makarna/pkg/backend/cpu/ops"
  6. "makarna/pkg/tensor"
  7. )
  8. // SwiGLUMLP computes the SwiGLU MLP block:
  9. // gate = x @ wGate
  10. // up = x @ wUp
  11. // hidden = SwiGLU(gate, up)
  12. // out = hidden @ wDown
  13. func SwiGLUMLP(x, wGate, wUp, wDown *cpu.Tensor) *cpu.Tensor {
  14. seqLen := x.Shape()[0]
  15. intermediate := wGate.Shape()[0]
  16. hiddenSize := wDown.Shape()[0]
  17. gate := ops.Zeros(tensor.Shape{seqLen, intermediate})
  18. up := ops.Zeros(tensor.Shape{seqLen, intermediate})
  19. matmul.Linear(x, wGate, gate)
  20. matmul.Linear(x, wUp, up)
  21. SwiGLU(gate, up, gate)
  22. out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
  23. matmul.Linear(gate, wDown, out)
  24. return out
  25. }