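// This file implements CPU reference kernels for KDA (a gated delta-rule
// attention variant): the per-channel decay gate, the recurrent state
// update, and the gated RMS normalization applied to the per-head outputs.
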
package nn

import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)
// KDAGate computes the per-channel log-decay gate
// g[h][d] = -exp(aLog[h]) * softplus(gFlat[h][d] + dtBias[h][d]),
// where aLog holds one learned log-magnitude per head and dtBias is an
// optional per-channel bias. It returns nil if the sizes are inconsistent.
func KDAGate(gFlat []float32, aLog []float32, headDim int, dtBias []float32) []float32 {
	h := len(aLog)
	if h*headDim != len(gFlat) {
		return nil
	}
	out := make([]float32, len(gFlat))
	for hi := 0; hi < h; hi++ {
		// Per-head decay magnitude. Softplus is strictly positive, so every
		// output is negative and exp(g) in the recurrence lies in (0, 1).
		mul := -float32(math.Exp(float64(aLog[hi])))
		base := hi * headDim
		for d := 0; d < headDim; d++ {
			x := gFlat[base+d]
			if dtBias != nil {
				x += dtBias[base+d]
			}
			out[base+d] = mul * Softplus(x)
		}
	}
	return out
}
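
// Illustrative sketch (not part of the original API): computing the gate for
// two heads of width two. The literal values are made up; the point is that
// every returned entry is negative, so exp(gate) is a valid decay factor.
func exampleKDAGate() []float32 {
	aLog := []float32{0.0, 0.5}             // one log-magnitude per head
	gFlat := []float32{0.1, -0.2, 0.3, 0.0} // two heads * headDim 2
	return KDAGate(gFlat, aLog, 2, nil)     // all entries < 0
}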
// KDARecurrent runs the gated delta-rule recurrence over a sequence. For each
// token and head it (1) decays the [headDim x headDim] state by exp(g) per key
// channel, (2) forms the delta v - k^T*S, (3) adds the rank-one update
// beta * k (x) delta to the state, and (4) overwrites vFlat in place with the
// output q^T*S, q pre-scaled by 1/sqrt(headDim). state must hold
// numHeads*headDim*headDim floats and is updated in place.
func KDARecurrent(qFlat, kFlat, vFlat, gFlat, beta []float32, state []float32, tokens, numHeads, headDim int) error {
	stride := numHeads * headDim
	strideB := numHeads
	stateStride := headDim * headDim
	if len(qFlat) < tokens*stride || len(kFlat) < tokens*stride || len(vFlat) < tokens*stride || len(gFlat) < tokens*stride {
		return fmt.Errorf("KDARecurrent: input size mismatch")
	}
	if len(beta) < tokens*strideB {
		return fmt.Errorf("KDARecurrent: beta size mismatch")
	}
	if state == nil || len(state) != numHeads*stateStride {
		return fmt.Errorf("KDARecurrent: state size mismatch")
	}
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	tmpKV := make([]float32, headDim)
	tmpVM := make([]float32, headDim)
	for t := 0; t < tokens; t++ {
		for h := 0; h < numHeads; h++ {
			off := t*stride + h*headDim
			b := beta[t*strideB+h]
			sOff := h * stateStride
			// (1) Decay each key row of the state by exp(g); g is negative
			// (see KDAGate), so this shrinks the stored memory.
			for kk := 0; kk < headDim; kk++ {
				dec := float32(math.Exp(float64(gFlat[off+kk])))
				rowBase := sOff + kk*headDim
				for vv := 0; vv < headDim; vv++ {
					state[rowBase+vv] *= dec
				}
			}
			// (2) Current prediction for v: tmpKV = k^T * S.
			for vv := 0; vv < headDim; vv++ {
				acc := float32(0)
				for kk := 0; kk < headDim; kk++ {
					acc += kFlat[off+kk] * state[sOff+kk*headDim+vv]
				}
				tmpKV[vv] = acc
			}
			// (3) Delta-rule error, then rank-one state update:
			// S += (beta * k) (x) (v - k^T*S).
			for vv := 0; vv < headDim; vv++ {
				tmpVM[vv] = vFlat[off+vv] - tmpKV[vv]
			}
			for kk := 0; kk < headDim; kk++ {
				kj := b * kFlat[off+kk]
				row := state[sOff+kk*headDim : sOff+(kk+1)*headDim]
				cpu.Axpy(kj, tmpVM, row)
			}
			// (4) Output o = (q * scale)^T * S, written back into vFlat.
			for vv := 0; vv < headDim; vv++ {
				acc := float32(0)
				for kk := 0; kk < headDim; kk++ {
					acc += (qFlat[off+kk] * scale) * state[sOff+kk*headDim+vv]
				}
				vFlat[off+vv] = acc
			}
		}
	}
	return nil
}
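
// Illustrative sketch (hypothetical shapes): driving KDARecurrent over a
// short sequence. After the call, vFlat holds the attention outputs and
// state holds the updated per-head memory, ready for the next chunk.
func exampleKDARecurrent() error {
	const tokens, numHeads, headDim = 3, 2, 4
	n := tokens * numHeads * headDim
	q := make([]float32, n)
	k := make([]float32, n)
	v := make([]float32, n)
	g := make([]float32, n) // log-decays from KDAGate; zeros mean no decay
	beta := make([]float32, tokens*numHeads)
	state := make([]float32, numHeads*headDim*headDim) // zero initial memory
	return KDARecurrent(q, k, v, g, beta, state, tokens, numHeads, headDim)
}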
// RMSNormGated applies per-group RMS normalization with a learned weight,
// optionally gating each element by sigmoid(g). out is processed in groups
// of headDim values and modified in place; len(out) must be a multiple of
// headDim. It is a no-op when weight is nil.
func RMSNormGated(out []float32, g []float32, weight []float32, headDim int, eps float32) {
	if weight == nil {
		return
	}
	for i := 0; i < len(out); i += headDim {
		// Mean square over the group, then the usual rsqrt(ms + eps).
		ss := float32(0)
		for j := 0; j < headDim; j++ {
			v := out[i+j]
			ss += v * v
		}
		inv := float32(1.0 / math.Sqrt(float64(ss/float32(headDim)+eps)))
		for j := 0; j < headDim; j++ {
			y := out[i+j] * inv * weight[j]
			if g != nil {
				y *= Sigmoid(g[i+j])
			}
			out[i+j] = y
		}
	}
}
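
// Illustrative end-to-end sketch (hypothetical shapes and values, single
// token): gate, then recurrence, then gated normalization. This shows one
// plausible way the three kernels above compose; the actual model wiring
// may differ.
func exampleKDAPipeline() error {
	const numHeads, headDim = 2, 4
	aLog := make([]float32, numHeads)
	raw := make([]float32, numHeads*headDim)
	gate := KDAGate(raw, aLog, headDim, nil) // per-channel log-decays
	q := make([]float32, numHeads*headDim)
	k := make([]float32, numHeads*headDim)
	v := make([]float32, numHeads*headDim)
	beta := make([]float32, numHeads)
	state := make([]float32, numHeads*headDim*headDim)
	if err := KDARecurrent(q, k, v, gate, beta, state, 1, numHeads, headDim); err != nil {
		return err
	}
	weight := make([]float32, headDim) // normally learned parameters
	gNorm := make([]float32, numHeads*headDim)
	RMSNormGated(v, gNorm, weight, headDim, 1e-6) // v now holds the final output
	return nil
}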

// FlattenALog is a thin convenience wrapper around FlattenVector.
func FlattenALog(t *cpu.Tensor, numHeads int) ([]float32, error) {
	return FlattenVector(t, numHeads, "A_log")
}