package nn

import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)

// ActivationKind selects the elementwise activation applied after the
// convolution.
type ActivationKind uint8

const (
	ActivationNone ActivationKind = iota
	ActivationSiLU
	ActivationTanh
	ActivationReLU
)

// applyActivation applies the selected activation to a single value.
// SiLU is x*sigmoid(x); unknown kinds fall through to identity.
func applyActivation(x float32, act ActivationKind) float32 {
	switch act {
	case ActivationNone:
		return x
	case ActivationSiLU:
		return x * Sigmoid(x)
	case ActivationReLU:
		if x < 0 {
			return 0
		}
		return x
	case ActivationTanh:
		return float32(math.Tanh(float64(x)))
	default:
		return x
	}
}

// FlattenConvWeights returns the depthwise conv weights flattened to
// channel-major order: out[d*kernel+j] is the weight for channel d at tap j.
// It accepts [projSize, kernel] and [kernel, projSize] layouts, plus the
// equivalent 3-D layouts with a singleton middle dimension, transposing
// when needed.
func FlattenConvWeights(w *cpu.Tensor, projSize int, kernel int) ([]float32, error) {
	if w == nil {
		return nil, fmt.Errorf("missing conv weights")
	}
	data := w.DataFloat32()
	expected := projSize * kernel
	shape := w.Shape()
	if shape.NumElements() < expected || len(data) < expected {
		return nil, fmt.Errorf("unexpected conv weight size %d", len(data))
	}
	if len(shape) == 2 {
		if shape[0] == projSize && shape[1] == kernel {
			// Already channel-major: use the buffer directly.
			return data[:expected], nil
		}
		if shape[0] == kernel && shape[1] == projSize {
			// Tap-major layout: transpose into channel-major order.
			out := make([]float32, expected)
			for d := 0; d < projSize; d++ {
				for j := 0; j < kernel; j++ {
					out[d*kernel+j] = data[j*projSize+d]
				}
			}
			return out, nil
		}
	}
	if len(shape) == 3 {
		if shape[0] == projSize && shape[1] == 1 && shape[2] == kernel {
			return data[:expected], nil
		}
		if shape[0] == kernel && shape[1] == 1 && shape[2] == projSize {
			out := make([]float32, expected)
			for d := 0; d < projSize; d++ {
				for j := 0; j < kernel; j++ {
					out[d*kernel+j] = data[j*projSize+d]
				}
			}
			return out, nil
		}
	}
	// Unrecognized shape, but the buffer is large enough (checked above):
	// assume it is already channel-major and take the first expected values.
	return data[:expected], nil
}
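
// For illustration (values are hypothetical, not from the package): with
// projSize=2 and kernel=3, a tap-major buffer of shape [3, 2] flattens to
// channel-major order as
//
//	data (tap-major, data[j*projSize+d]):  [a0 b0 a1 b1 a2 b2]
//	out  (channel-major, out[d*kernel+j]): [a0 a1 a2 b0 b1 b2]
//
// where aj and bj are the j-th taps of channels a and b.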

// CausalShortConv1DInplaceAct applies a depthwise causal short conv1d over
// xFlat (tokens x projSize, row-major per token), followed by the given
// activation, writing the result back into xFlat. state holds the previous
// kernel-1 inputs per channel and is updated in place as tokens are consumed.
func CausalShortConv1DInplaceAct(xFlat []float32, state *cpu.Tensor, w *cpu.Tensor, tokens int, projSize int, kernel int, act ActivationKind) error {
	if kernel <= 1 {
		// Degenerate kernel: no temporal mixing, activation only.
		for i := range xFlat {
			xFlat[i] = applyActivation(xFlat[i], act)
		}
		return nil
	}
	convLen := kernel - 1
	if state == nil {
		return fmt.Errorf("nil conv state")
	}
	if state.Shape().NumElements() != projSize*convLen {
		return fmt.Errorf("conv state shape mismatch %v", state.Shape())
	}
	if len(xFlat) < tokens*projSize {
		return fmt.Errorf("conv input too short: %d < %d", len(xFlat), tokens*projSize)
	}
	weights, err := FlattenConvWeights(w, projSize, kernel)
	if err != nil {
		return err
	}
	st := state.DataFloat32()
	out := make([]float32, len(xFlat))
	for t := 0; t < tokens; t++ {
		base := t * projSize
		for d := 0; d < projSize; d++ {
			// Dot product of the kernel with [past inputs..., current input].
			acc := float32(0)
			wBase := d * kernel
			for j := 0; j < convLen; j++ {
				acc += weights[wBase+j] * st[d*convLen+j]
			}
			acc += weights[wBase+convLen] * xFlat[base+d]
			out[base+d] = applyActivation(acc, act)
		}
		// Shift each channel's history left and append the current input.
		for d := 0; d < projSize; d++ {
			off := d * convLen
			copy(st[off:off+convLen-1], st[off+1:off+convLen])
			st[off+convLen-1] = xFlat[base+d]
		}
	}
	copy(xFlat, out)
	return nil
}
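
// A worked example of the per-channel recurrence above (illustrative values,
// not from the package): with kernel=3, channel weights [w0 w1 w2], and
// history state=[s0 s1], one step on input x computes
//
//	y = act(w0*s0 + w1*s1 + w2*x)
//
// and then shifts the history to state=[s1 x] for the next token.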

// CausalShortConv1DInplace is the backward-compatible API.
// It applies a causal short conv1d followed by SiLU.
func CausalShortConv1DInplace(xFlat []float32, state *cpu.Tensor, w *cpu.Tensor, tokens int, projSize int, kernel int) error {
	return CausalShortConv1DInplaceAct(xFlat, state, w, tokens, projSize, kernel, ActivationSiLU)
}
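
// Usage sketch (hypothetical setup; how the state and weight tensors are
// constructed depends on the cpu package and is not shown here):
//
//	// xFlat holds tokens*projSize inputs, row-major per token; state is a
//	// zero-initialized [projSize, kernel-1] tensor; w holds the conv weights
//	// in one of the layouts FlattenConvWeights accepts.
//	if err := CausalShortConv1DInplace(xFlat, state, w, tokens, projSize, kernel); err != nil {
//		return err
//	}
//	// xFlat now holds SiLU(conv(x)) and state carries the last kernel-1
//	// inputs, ready for the next chunk of tokens.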