package nn

import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)

// ActivationKind selects the elementwise nonlinearity applied after the
// convolution.
type ActivationKind uint8

const (
	ActivationNone ActivationKind = iota
	ActivationSiLU
	ActivationTanh
	ActivationReLU
)

// applyActivation applies the selected activation to a single value.
// Unknown kinds fall through to the identity.
func applyActivation(x float32, act ActivationKind) float32 {
	switch act {
	case ActivationNone:
		return x
	case ActivationSiLU:
		return x * Sigmoid(x)
	case ActivationReLU:
		if x < 0 {
			return 0
		}
		return x
	case ActivationTanh:
		return float32(math.Tanh(float64(x)))
	default:
		return x
	}
}

// FlattenConvWeights normalizes depthwise conv weights into a flat row-major
// [projSize][kernel] layout (oldest tap first per channel). It accepts 2D
// [projSize][kernel] or [kernel][projSize] tensors and 3D variants with a
// singleton middle dimension, transposing when needed. When projSize ==
// kernel the two layouts are indistinguishable and [projSize][kernel] order
// is assumed.
func FlattenConvWeights(w *cpu.Tensor, projSize int, kernel int) ([]float32, error) {
	if w == nil {
		return nil, fmt.Errorf("missing conv weights")
	}
	data := w.DataFloat32()
	expected := projSize * kernel
	shape := w.Shape()
	if shape.NumElements() < expected || len(data) < expected {
		return nil, fmt.Errorf("unexpected conv weight size %d", len(data))
	}
	if len(shape) == 2 {
		if shape[0] == projSize && shape[1] == kernel {
			return data[:expected], nil
		}
		if shape[0] == kernel && shape[1] == projSize {
			out := make([]float32, expected)
			for d := 0; d < projSize; d++ {
				for j := 0; j < kernel; j++ {
					out[d*kernel+j] = data[j*projSize+d]
				}
			}
			return out, nil
		}
	}
	if len(shape) == 3 {
		if shape[0] == projSize && shape[1] == 1 && shape[2] == kernel {
			return data[:expected], nil
		}
		if shape[0] == kernel && shape[1] == 1 && shape[2] == projSize {
			out := make([]float32, expected)
			for d := 0; d < projSize; d++ {
				for j := 0; j < kernel; j++ {
					out[d*kernel+j] = data[j*projSize+d]
				}
			}
			return out, nil
		}
	}
	// Fallback: the shape matched no known layout, but the length check above
	// guarantees enough elements, so assume the data is already in
	// [projSize][kernel] order.
	return data[:expected], nil
}

// CausalShortConv1DInplaceAct applies a depthwise causal short conv1d to
// xFlat, a row-major [tokens][projSize] buffer, followed by the given
// activation, overwriting xFlat with the result. state holds the last
// kernel-1 raw inputs per channel (oldest first) and is updated as tokens
// are consumed, so successive calls continue the same sequence.
func CausalShortConv1DInplaceAct(xFlat []float32, state *cpu.Tensor, w *cpu.Tensor, tokens int, projSize int, kernel int, act ActivationKind) error {
	if kernel <= 1 {
		// Degenerate kernel: no history, activation only.
		for i := range xFlat {
			xFlat[i] = applyActivation(xFlat[i], act)
		}
		return nil
	}
	if len(xFlat) < tokens*projSize {
		return fmt.Errorf("input length %d < tokens*projSize %d", len(xFlat), tokens*projSize)
	}
	convLen := kernel - 1
	if state == nil {
		return fmt.Errorf("nil conv state")
	}
	if state.Shape().NumElements() != projSize*convLen {
		return fmt.Errorf("conv state shape mismatch %v", state.Shape())
	}
	weights, err := FlattenConvWeights(w, projSize, kernel)
	if err != nil {
		return err
	}
	st := state.DataFloat32()
	out := make([]float32, len(xFlat))
	for t := 0; t < tokens; t++ {
		base := t * projSize
		for d := 0; d < projSize; d++ {
			// Dot product of the channel's taps with its history plus the
			// current input; the newest tap (index kernel-1) multiplies the
			// current sample.
			acc := float32(0)
			wBase := d * kernel
			for j := 0; j < convLen; j++ {
				acc += weights[wBase+j] * st[d*convLen+j]
			}
			acc += weights[wBase+convLen] * xFlat[base+d]
			out[base+d] = applyActivation(acc, act)
		}
		// Shift each channel's history left by one and append the current
		// raw input (kernel > 1 guarantees convLen >= 1 here).
		for d := 0; d < projSize; d++ {
			off := d * convLen
			copy(st[off:off+convLen-1], st[off+1:off+convLen])
			st[off+convLen-1] = xFlat[base+d]
		}
	}
	copy(xFlat, out)
	return nil
}

// CausalShortConv1DInplace is the backward-compatible API.
// It applies a causal short conv1d followed by SiLU.
func CausalShortConv1DInplace(xFlat []float32, state *cpu.Tensor, w *cpu.Tensor, tokens int, projSize int, kernel int) error {
	return CausalShortConv1DInplaceAct(xFlat, state, w, tokens, projSize, kernel, ActivationSiLU)
}
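
// Usage sketch: a minimal, hedged example of driving the public API. The
// cpu.NewTensor constructor below is hypothetical — the actual allocation
// API in makarna/pkg/backend/cpu may differ — but the shapes shown are the
// ones the functions above validate: state is [projSize][kernel-1] (the
// rolling raw-input history), w is [projSize][kernel] (oldest tap first),
// and xFlat is a row-major [tokens][projSize] buffer updated in place.
//
//	projSize, kernel, tokens := 4, 3, 2
//	state := cpu.NewTensor(projSize, kernel-1) // hypothetical; zero-initialized history
//	w := cpu.NewTensor(projSize, kernel)       // hypothetical; fill with filter taps
//	xFlat := make([]float32, tokens*projSize)  // input tokens, overwritten with outputs
//	if err := CausalShortConv1DInplace(xFlat, state, w, tokens, projSize, kernel); err != nil {
//		// handle error
//	}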