| 1234567891011121314151617181920212223242526272829303132 |
- package nn
- import (
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cpu/matmul"
- "makarna/pkg/backend/cpu/ops"
- "makarna/pkg/tensor"
- )
- // SwiGLUMLP computes the SwiGLU MLP block:
- // gate = x @ wGate
- // up = x @ wUp
- // hidden = SwiGLU(gate, up)
- // out = hidden @ wDown
- func SwiGLUMLP(x, wGate, wUp, wDown *cpu.Tensor) *cpu.Tensor {
- seqLen := x.Shape()[0]
- intermediate := wGate.Shape()[0]
- hiddenSize := wDown.Shape()[0]
- gate := ops.Zeros(tensor.Shape{seqLen, intermediate})
- up := ops.Zeros(tensor.Shape{seqLen, intermediate})
- matmul.Linear(x, wGate, gate)
- matmul.Linear(x, wUp, up)
- SwiGLU(gate, up, gate)
- out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
- matmul.Linear(gate, wDown, out)
- return out
- }
|