package nn

import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)

// KDAGate turns a raw gate projection into per-channel log-decay values:
//
//	out[h,d] = -exp(aLog[h]) * softplus(gFlat[h,d] + dtBias[h,d])
//
// aLog holds one learned log scale per head; dtBias may be nil. The result is
// strictly negative, so exp(out) used in KDARecurrent is a decay in (0, 1).
// Returns nil when len(gFlat) != len(aLog)*headDim.
func KDAGate(gFlat []float32, aLog []float32, headDim int, dtBias []float32) []float32 {
	h := len(aLog)
	if h*headDim != len(gFlat) {
		return nil
	}
	out := make([]float32, len(gFlat))
	for hi := 0; hi < h; hi++ {
		mul := -float32(math.Exp(float64(aLog[hi])))
		base := hi * headDim
		for d := 0; d < headDim; d++ {
			x := gFlat[base+d]
			if dtBias != nil {
				x += dtBias[base+d]
			}
			out[base+d] = mul * Softplus(x)
		}
	}
	return out
}

// KDARecurrent runs the gated delta-rule recurrence over tokens sequential
// steps. Inputs are flattened [tokens, numHeads, headDim]; beta is flattened
// [tokens, numHeads]. For each token t and head h, with per-channel decay
// a = exp(g[t,h,:]), the [headDim x headDim] state S (rows indexed by key
// channel, columns by value channel) is updated as
//
//	S <- diag(a) · S
//	S <- S + beta[t,h] · k ⊗ (v - kᵀS)
//	out = (q / sqrt(headDim))ᵀ · S
//
// The head output is written back into vFlat in place, and state is mutated
// in place so it can be carried across calls; it must have length
// numHeads*headDim*headDim.
func KDARecurrent(qFlat, kFlat, vFlat, gFlat, beta []float32, state []float32, tokens, numHeads, headDim int) error {
	stride := numHeads * headDim
	strideB := numHeads
	stateStride := headDim * headDim
	if len(qFlat) < tokens*stride || len(kFlat) < tokens*stride || len(vFlat) < tokens*stride || len(gFlat) < tokens*stride {
		return fmt.Errorf("KDARecurrent: input size mismatch")
	}
	if len(beta) < tokens*strideB {
		return fmt.Errorf("KDARecurrent: beta size mismatch")
	}
	if state == nil || len(state) != numHeads*stateStride {
		return fmt.Errorf("KDARecurrent: state size mismatch")
	}
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	tmpKV := make([]float32, headDim)
	tmpVM := make([]float32, headDim)
	for t := 0; t < tokens; t++ {
		for h := 0; h < numHeads; h++ {
			off := t*stride + h*headDim
			b := beta[t*strideB+h]
			SOff := h * stateStride
			// Decay: scale each state row by exp(g) for its key channel.
			for kk := 0; kk < headDim; kk++ {
				dec := float32(math.Exp(float64(gFlat[off+kk])))
				rowBase := SOff + kk*headDim
				for vv := 0; vv < headDim; vv++ {
					state[rowBase+vv] *= dec
				}
			}
			// tmpKV = kᵀS: the state's current prediction for v.
			for vv := 0; vv < headDim; vv++ {
				acc := float32(0)
				for kk := 0; kk < headDim; kk++ {
					acc += kFlat[off+kk] * state[SOff+kk*headDim+vv]
				}
				tmpKV[vv] = acc
			}
			// tmpVM = v - kᵀS: the delta-rule error term.
			for vv := 0; vv < headDim; vv++ {
				tmpVM[vv] = vFlat[off+vv] - tmpKV[vv]
			}
			// Rank-1 update: S += beta * k ⊗ tmpVM, one axpy per state row.
			for kk := 0; kk < headDim; kk++ {
				kj := b * kFlat[off+kk]
				row := state[SOff+kk*headDim : SOff+(kk+1)*headDim]
				cpu.Axpy(kj, tmpVM, row)
			}
			// Output: (q * scale)ᵀ S, written over vFlat in place.
			for vv := 0; vv < headDim; vv++ {
				acc := float32(0)
				for kk := 0; kk < headDim; kk++ {
					acc += (qFlat[off+kk] * scale) * state[SOff+kk*headDim+vv]
				}
				vFlat[off+vv] = acc
			}
		}
	}
	return nil
}

// RMSNormGated applies RMS normalization to out in place, one headDim-sized
// group at a time, scaling by weight and, when g is non-nil, gating each
// channel by sigmoid(g). It is a no-op when weight is nil.
func RMSNormGated(out []float32, g []float32, weight []float32, headDim int, eps float32) {
	if weight == nil {
		return
	}
	for i := 0; i < len(out); i += headDim {
		ss := float32(0)
		for j := 0; j < headDim; j++ {
			v := out[i+j]
			ss += v * v
		}
		inv := float32(1.0 / math.Sqrt(float64(ss/float32(headDim)+eps)))
		for j := 0; j < headDim; j++ {
			y := out[i+j] * inv * weight[j]
			if g != nil {
				y *= Sigmoid(g[i+j])
			}
			out[i+j] = y
		}
	}
}

// FlattenALog is a thin convenience wrapper around FlattenVector.
func FlattenALog(t *cpu.Tensor, numHeads int) ([]float32, error) {
	return FlattenVector(t, numHeads, "A_log")
}
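
// The sketch below is illustrative and not part of the original file: it
// wires KDAGate, KDARecurrent, and RMSNormGated together for a single decode
// step. The shapes, the literal activations, and the all-ones beta and norm
// weight are placeholder assumptions; in a real model q, k, v, the gate
// projection, beta, A_log, dtBias, and the norm weight all come from learned
// parameters.
func exampleKDAStep() error {
	const (
		tokens   = 1
		numHeads = 2
		headDim  = 4
	)
	n := tokens * numHeads * headDim

	// Placeholder activations; real values come from the model's projections.
	q := make([]float32, n)
	k := make([]float32, n)
	v := make([]float32, n)
	gProj := make([]float32, n) // raw gate projection, one value per channel
	for i := range q {
		q[i], k[i], v[i] = 0.1, 0.1, float32(i%headDim)*0.05
	}

	aLog := make([]float32, numHeads) // learned per-head log-decay scale
	beta := make([]float32, tokens*numHeads)
	for i := range beta {
		beta[i] = 1 // placeholder; typically sigmoid of a beta projection
	}

	// Per-channel log-decay, strictly negative by construction (nil dtBias).
	g := KDAGate(gProj, aLog, headDim, nil)
	if g == nil {
		return fmt.Errorf("exampleKDAStep: gate shape mismatch")
	}

	// Zero state at the start of a sequence; carried across calls otherwise.
	state := make([]float32, numHeads*headDim*headDim)
	if err := KDARecurrent(q, k, v, g, beta, state, tokens, numHeads, headDim); err != nil {
		return err
	}

	// KDARecurrent wrote the per-head outputs into v; normalize them in place.
	weight := make([]float32, headDim)
	for i := range weight {
		weight[i] = 1
	}
	RMSNormGated(v, nil, weight, headDim, 1e-6)
	return nil
}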