kda.go

package nn

import (
	"fmt"
	"math"

	"makarna/pkg/backend/cpu"
)

// KDAGate computes the per-channel log-decay gate for KDA:
// out[h*headDim+d] = -exp(aLog[h]) * softplus(gFlat[h*headDim+d] + dtBias[h*headDim+d]).
// It returns nil if the flat gate input does not match len(aLog)*headDim.
func KDAGate(gFlat []float32, aLog []float32, headDim int, dtBias []float32) []float32 {
	h := len(aLog)
	if h*headDim != len(gFlat) {
		return nil
	}
	out := make([]float32, len(gFlat))
	for hi := 0; hi < h; hi++ {
		mul := -float32(math.Exp(float64(aLog[hi])))
		base := hi * headDim
		for d := 0; d < headDim; d++ {
			x := gFlat[base+d]
			if dtBias != nil {
				x += dtBias[base+d]
			}
			out[base+d] = mul * Softplus(x)
		}
	}
	return out
}
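
// exampleKDAGate is an illustrative sketch, not part of the original file: it
// shows the shape contract (len(gFlat) == len(aLog)*headDim) with assumed toy
// values. The function name and all values here are invented for illustration;
// dtBias may be nil when the model carries no dt bias.
func exampleKDAGate() {
	aLog := []float32{0.0, 0.5}           // one log-coefficient per head
	gFlat := make([]float32, 2*4)         // numHeads*headDim pre-activations
	gates := KDAGate(gFlat, aLog, 4, nil) // each entry: -exp(aLog[h]) * Softplus(x)
	fmt.Println(len(gates))               // 8; entries are negative log-decays for KDARecurrent
}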
// KDARecurrent runs the KDA gated delta-rule recurrence over a flat batch of
// tokens. Each head keeps a [headDim x headDim] state matrix; per token it
// decays the state row-wise by exp(g), reads it with k, writes the beta-scaled
// delta (v minus the read) back along k, and overwrites vFlat in place with
// the readout of the scaled query against the updated state.
func KDARecurrent(qFlat, kFlat, vFlat, gFlat, beta []float32, state []float32, tokens, numHeads, headDim int) error {
	stride := numHeads * headDim
	strideB := numHeads
	stateStride := headDim * headDim
	if len(qFlat) < tokens*stride || len(kFlat) < tokens*stride || len(vFlat) < tokens*stride || len(gFlat) < tokens*stride {
		return fmt.Errorf("KDARecurrent: input size mismatch")
	}
	if len(beta) < tokens*strideB {
		return fmt.Errorf("KDARecurrent: beta size mismatch")
	}
	if state == nil || len(state) != numHeads*stateStride {
		return fmt.Errorf("KDARecurrent: state size mismatch")
	}
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	tmpKV := make([]float32, headDim)
	tmpVM := make([]float32, headDim)
	for t := 0; t < tokens; t++ {
		for h := 0; h < numHeads; h++ {
			off := t*stride + h*headDim
			b := beta[t*strideB+h]
			SOff := h * stateStride
			// Decay each key row of the state; gFlat holds log-decays
			// (e.g. KDAGate output), so exp(g) lies in (0, 1).
			for kk := 0; kk < headDim; kk++ {
				dec := float32(math.Exp(float64(gFlat[off+kk])))
				rowBase := SOff + kk*headDim
				for vv := 0; vv < headDim; vv++ {
					state[rowBase+vv] *= dec
				}
			}
			// Read the current state with k: tmpKV[vv] = sum_kk k[kk]*S[kk][vv].
			for vv := 0; vv < headDim; vv++ {
				acc := float32(0)
				for kk := 0; kk < headDim; kk++ {
					acc += kFlat[off+kk] * state[SOff+kk*headDim+vv]
				}
				tmpKV[vv] = acc
			}
			// Delta between the target value and what the state already stores.
			for vv := 0; vv < headDim; vv++ {
				tmpVM[vv] = vFlat[off+vv] - tmpKV[vv]
			}
			// Rank-1 write: S[kk][:] += beta*k[kk] * tmpVM, one row at a time.
			for kk := 0; kk < headDim; kk++ {
				kj := b * kFlat[off+kk]
				row := state[SOff+kk*headDim : SOff+(kk+1)*headDim]
				cpu.Axpy(kj, tmpVM, row)
			}
			// Readout with the scaled query, written in place over vFlat.
			for vv := 0; vv < headDim; vv++ {
				acc := float32(0)
				for kk := 0; kk < headDim; kk++ {
					acc += (qFlat[off+kk] * scale) * state[SOff+kk*headDim+vv]
				}
				vFlat[off+vv] = acc
			}
		}
	}
	return nil
}
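
// exampleKDARecurrent is an illustrative sketch, not part of the original
// file: a single decode step with zeroed toy inputs. The function name and
// sizes are invented for illustration. state persists across calls, and vFlat
// doubles as the output buffer.
func exampleKDARecurrent() {
	const tokens, numHeads, headDim = 1, 2, 4
	n := tokens * numHeads * headDim
	q, k, v := make([]float32, n), make([]float32, n), make([]float32, n)
	g := make([]float32, n)                            // typically KDAGate output
	beta := make([]float32, tokens*numHeads)           // per-head write strength
	state := make([]float32, numHeads*headDim*headDim) // carried between calls
	if err := KDARecurrent(q, k, v, g, beta, state, tokens, numHeads, headDim); err != nil {
		fmt.Println(err)
		return
	}
	// v now holds the per-head readouts for the token.
}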
// RMSNormGated applies RMS normalization per headDim-sized group with a shared
// per-channel weight, optionally gated by sigmoid(g). It mutates out in place
// and is a no-op when weight is nil.
func RMSNormGated(out []float32, g []float32, weight []float32, headDim int, eps float32) {
	if weight == nil {
		return
	}
	for i := 0; i < len(out); i += headDim {
		ss := float32(0)
		for j := 0; j < headDim; j++ {
			v := out[i+j]
			ss += v * v
		}
		inv := float32(1.0 / math.Sqrt(float64(ss/float32(headDim)+eps)))
		for j := 0; j < headDim; j++ {
			y := out[i+j] * inv * weight[j]
			if g != nil {
				y *= Sigmoid(g[i+j])
			}
			out[i+j] = y
		}
	}
}
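
// exampleRMSNormGated is an illustrative sketch, not part of the original
// file: gated RMS-norming two heads of recurrence output with assumed values.
func exampleRMSNormGated() {
	const headDim = 4
	out := []float32{1, 2, 3, 4, 5, 6, 7, 8} // two heads, normalized per head
	gate := make([]float32, len(out))        // pre-sigmoid gate; nil disables gating
	weight := []float32{1, 1, 1, 1}          // per-channel scale, len == headDim
	RMSNormGated(out, gate, weight, headDim, 1e-6)
}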
// FlattenALog is a thin convenience wrapper around FlattenVector.
func FlattenALog(t *cpu.Tensor, numHeads int) ([]float32, error) {
	return FlattenVector(t, numHeads, "A_log")
}