nn.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. //go:build !cuda
  2. // Package compute provides device-agnostic neural network operations.
  3. package compute
  4. import (
  5. "fmt"
  6. "math"
  7. "makarna/pkg/backend/cpu"
  8. "makarna/pkg/backend/cpu/nn"
  9. "makarna/pkg/backend/device"
  10. "makarna/pkg/tensor"
  11. )
  12. // RMSNorm applies RMS normalization in-place.
  13. // For GPU tensors, temporarily copies to CPU (CUDA kernel not implemented yet).
  14. func RMSNorm(ctx *Context, x, w tensor.Tensor, eps float32) error {
  15. // Currently always use CPU path
  16. // TODO: Implement CUDA RMSNorm kernel
  17. xCPU, err := ensureCPUForOp(x)
  18. if err != nil {
  19. return fmt.Errorf("rmsnorm: %w", err)
  20. }
  21. wCPU, ok := w.(*cpu.Tensor)
  22. if !ok {
  23. return fmt.Errorf("rmsnorm: weight must be CPU tensor")
  24. }
  25. return rmsNormCPU(xCPU, wCPU, eps)
  26. }
  27. func rmsNormCPU(x, w *cpu.Tensor, eps float32) error {
  28. xData := x.DataFloat32()
  29. wData := w.DataFloat32()
  30. dim := w.Shape().NumElements()
  31. numRows := x.Shape().NumElements() / dim
  32. for i := 0; i < numRows; i++ {
  33. row := xData[i*dim : (i+1)*dim]
  34. // Sum of squares
  35. ss := cpu.DotFloat32(row, row) / float32(dim)
  36. // Normalize and scale
  37. invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
  38. for j := 0; j < dim; j++ {
  39. row[j] = row[j] * invRMS * wData[j]
  40. }
  41. }
  42. return nil
  43. }
  44. // RoPE applies Rotary Positional Embeddings in-place.
  45. func RoPE(ctx *Context, x tensor.Tensor, positions []int, headDim int, theta float32) error {
  46. xCPU, err := ensureCPUForOp(x)
  47. if err != nil {
  48. return fmt.Errorf("rope: %w", err)
  49. }
  50. return ropeCPU(xCPU, positions, headDim, theta)
  51. }
  52. func ropeCPU(x *cpu.Tensor, positions []int, headDim int, theta float32) error {
  53. data := x.DataFloat32()
  54. shape := x.Shape()
  55. seqLen := shape[0]
  56. totalDim := shape[1]
  57. halfDim := headDim / 2
  58. invFreqs := make([]float64, halfDim)
  59. for j := 0; j < halfDim; j++ {
  60. invFreqs[j] = 1.0 / math.Pow(float64(theta), float64(2*j)/float64(headDim))
  61. }
  62. for seq := 0; seq < seqLen; seq++ {
  63. pos := positions[seq]
  64. rowStart := seq * totalDim
  65. for headStart := 0; headStart < totalDim; headStart += headDim {
  66. for j := 0; j < halfDim; j++ {
  67. freq := float64(pos) * invFreqs[j]
  68. sin, cos := math.Sincos(freq)
  69. idx0 := rowStart + headStart + j
  70. idx1 := rowStart + headStart + j + halfDim
  71. v0 := data[idx0]
  72. v1 := data[idx1]
  73. data[idx0] = v0*float32(cos) - v1*float32(sin)
  74. data[idx1] = v1*float32(cos) + v0*float32(sin)
  75. }
  76. }
  77. }
  78. return nil
  79. }
  80. // SwiGLU computes SiLU(gate) * up and stores in out.
  81. func SwiGLU(ctx *Context, gate, up, out tensor.Tensor) error {
  82. gCPU, err := ensureCPUForOp(gate)
  83. if err != nil {
  84. return err
  85. }
  86. uCPU, err := ensureCPUForOp(up)
  87. if err != nil {
  88. return err
  89. }
  90. oCPU, err := ensureCPUForOp(out)
  91. if err != nil {
  92. return err
  93. }
  94. return swigluCPU(gCPU, uCPU, oCPU)
  95. }
  96. func swigluCPU(gate, up, out *cpu.Tensor) error {
  97. gData := gate.DataFloat32()
  98. uData := up.DataFloat32()
  99. oData := out.DataFloat32()
  100. if len(oData) == 0 {
  101. return nil
  102. }
  103. copy(oData, gData)
  104. if err := nn.SiLU(out); err != nil {
  105. return err
  106. }
  107. for i := range oData {
  108. oData[i] *= uData[i]
  109. }
  110. return nil
  111. }
  112. // Softmax applies softmax along the last dimension.
  113. func Softmax(ctx *Context, x tensor.Tensor) error {
  114. xCPU, err := ensureCPUForOp(x)
  115. if err != nil {
  116. return err
  117. }
  118. return softmaxCPU(xCPU)
  119. }
  120. func softmaxCPU(x *cpu.Tensor) error {
  121. data := x.DataFloat32()
  122. shape := x.Shape()
  123. if len(shape) != 2 {
  124. return fmt.Errorf("softmax: expected 2D tensor")
  125. }
  126. rows := shape[0]
  127. cols := shape[1]
  128. for i := 0; i < rows; i++ {
  129. row := data[i*cols : (i+1)*cols]
  130. // Find max for numerical stability
  131. maxVal := row[0]
  132. for _, v := range row[1:] {
  133. if v > maxVal {
  134. maxVal = v
  135. }
  136. }
  137. // Exp and sum
  138. sum := float32(0)
  139. for j := range row {
  140. row[j] = float32(math.Exp(float64(row[j] - maxVal)))
  141. sum += row[j]
  142. }
  143. // Normalize
  144. for j := range row {
  145. row[j] /= sum
  146. }
  147. }
  148. return nil
  149. }
  150. // Add performs element-wise addition: dst += src
  151. func Add(dst, src tensor.Tensor) error {
  152. dCPU, err := ensureCPUForOp(dst)
  153. if err != nil {
  154. return err
  155. }
  156. sCPU, err := ensureCPUForOp(src)
  157. if err != nil {
  158. return err
  159. }
  160. dData := dCPU.DataFloat32()
  161. sData := sCPU.DataFloat32()
  162. for i := range dData {
  163. dData[i] += sData[i]
  164. }
  165. return nil
  166. }
  167. // CopyData copies tensor data: dst = src
  168. func CopyData(dst, src tensor.Tensor) error {
  169. dCPU, err := ensureCPUForOp(dst)
  170. if err != nil {
  171. return err
  172. }
  173. sCPU, err := ensureCPUForOp(src)
  174. if err != nil {
  175. return err
  176. }
  177. copy(dCPU.DataFloat32(), sCPU.DataFloat32())
  178. return nil
  179. }
  180. // ensureCPUForOp converts tensor to CPU if needed (temporary until CUDA kernels are done)
  181. func ensureCPUForOp(t tensor.Tensor) (*cpu.Tensor, error) {
  182. if cpuT, ok := t.(*cpu.Tensor); ok {
  183. return cpuT, nil
  184. }
  185. // Convert to CPU
  186. result, err := device.EnsureOn(t, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
  187. if err != nil {
  188. return nil, err
  189. }
  190. cpuT, ok := result.(*cpu.Tensor)
  191. if !ok {
  192. return nil, fmt.Errorf("expected CPU tensor after conversion")
  193. }
  194. return cpuT, nil
  195. }