nn_cuda.go

//go:build cuda

// Package compute provides CUDA-accelerated neural network operations.
package compute

import (
	"fmt"
	"math"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)

// RMSNorm applies RMS normalization in-place, scaling each row of x by the
// weight vector w. Runs the CUDA path when the context is placed on a GPU
// and CUDA is available; otherwise falls back to the CPU implementation.
func RMSNorm(ctx *Context, x, w tensor.Tensor, eps float32) error {
	useGPU := ctx != nil && ctx.IsGPU() && device.CUDAAvailable()
	if !useGPU {
		profile.Start("RMSNorm/CPU")
		err := rmsNormCPU(x, w, eps)
		profile.End("RMSNorm/CPU")
		return err
	}
	profile.Start("RMSNorm/GPU")
	defer profile.End("RMSNorm/GPU")

	// Both operands must be host tensors; data is staged to the GPU below.
	xCPU, ok := x.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("rmsnorm: x must be CPU tensor")
	}
	wCPU, ok := w.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("rmsnorm: w must be CPU tensor")
	}
	gpu := ctx.Placement().GPU
	shape := x.Shape()
	seqLen := shape[0]
	dim := shape[1]

	// Upload x to the GPU.
	profile.Start("RMSNorm/alloc_x")
	gpuX, err := cuda.NewTensor(shape, tensor.Float32, gpu)
	profile.End("RMSNorm/alloc_x")
	if err != nil {
		return err
	}
	if err := gpuX.CopyFrom(xCPU.DataFloat32()); err != nil {
		return err
	}

	// Fetch the weight from the per-GPU cache, uploading it on first use.
	cache := GetWeightCache(gpu)
	wKey := fmt.Sprintf("rmsnorm_%p", wCPU)
	gpuW, ok := cache.Get(wKey)
	if !ok {
		gpuW, err = cache.Upload(wKey, wCPU)
		if err != nil {
			return err
		}
	}

	// Run the CUDA RMSNorm kernel.
	profile.Start("RMSNorm/kernel")
	err = cuda.RMSNorm(gpuX.Data().(unsafe.Pointer), gpuW, seqLen, dim, eps, gpu)
	profile.End("RMSNorm/kernel")
	if err != nil {
		return err
	}

	// Copy the normalized rows back to the host.
	if err := gpuX.CopyToHost(xCPU.DataFloat32()); err != nil {
		return err
	}
	return nil
}
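
// A minimal usage sketch for RMSNorm (illustrative only; tensor construction
// is elided because the constructors are not shown in this file):
//
//	// x: [seqLen, dim] activations, w: [dim] scale weights, both *cpu.Tensor
//	if err := RMSNorm(ctx, x, w, 1e-6); err != nil {
//		return err
//	}
//
// An eps around 1e-5 or 1e-6 is typical; it only guards against division by
// zero on near-zero rows.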

// rmsNormCPU normalizes each row on the host:
// y[j] = x[j] * w[j] / sqrt(mean(x^2) + eps).
func rmsNormCPU(x, w tensor.Tensor, eps float32) error {
	xCPU := x.(*cpu.Tensor)
	wCPU := w.(*cpu.Tensor)
	xData := xCPU.DataFloat32()
	wData := wCPU.DataFloat32()
	dim := wCPU.Shape().NumElements()
	numRows := xCPU.Shape().NumElements() / dim
	for i := 0; i < numRows; i++ {
		row := xData[i*dim : (i+1)*dim]
		ss := cpu.DotFloat32(row, row) / float32(dim)
		invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
		for j := 0; j < dim; j++ {
			row[j] = row[j] * invRMS * wData[j]
		}
	}
	return nil
}
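
// Worked example of rmsNormCPU's math: for a row x = [3, 4] with w = [1, 1]
// and eps = 0, mean(x^2) = (9+16)/2 = 12.5, invRMS = 1/sqrt(12.5) ≈ 0.2828,
// so the row becomes ≈ [0.8485, 1.1314].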

// RoPE applies Rotary Positional Embeddings in-place. x is expected to be
// [seqLen, numHeads*headDim], with one entry in positions per row.
func RoPE(ctx *Context, x tensor.Tensor, positions []int, headDim int, theta float32) error {
	useGPU := ctx != nil && ctx.IsGPU() && device.CUDAAvailable()
	if !useGPU {
		profile.Start("RoPE/CPU")
		err := ropeCPU(x, positions, headDim, theta)
		profile.End("RoPE/CPU")
		return err
	}
	profile.Start("RoPE/GPU")
	defer profile.End("RoPE/GPU")

	xCPU, ok := x.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("rope: x must be CPU tensor")
	}
	gpu := ctx.Placement().GPU
	shape := x.Shape()
	seqLen := shape[0]
	totalDim := shape[1]
	numHeads := totalDim / headDim

	// Upload x to the GPU.
	gpuX, err := cuda.NewTensor(shape, tensor.Float32, gpu)
	if err != nil {
		return err
	}
	if err := gpuX.CopyFrom(xCPU.DataFloat32()); err != nil {
		return err
	}

	// Upload positions as int32. The tensor is allocated as Float32 purely
	// for its byte size: float32 and int32 are both 4 bytes, so the int32
	// data is bit-reinterpreted as a float32 slice for the copy.
	posData := make([]int32, len(positions))
	for i, p := range positions {
		posData[i] = int32(p)
	}
	gpuPos, err := cuda.NewTensor(tensor.Shape{len(positions)}, tensor.Float32, gpu)
	if err != nil {
		return err
	}
	posPtr := unsafe.Pointer(&posData[0])
	if err := gpuPos.CopyFrom(unsafe.Slice((*float32)(posPtr), len(posData))); err != nil {
		return err
	}

	// Run the CUDA RoPE kernel.
	profile.Start("RoPE/kernel")
	err = cuda.RoPE(gpuX.Data().(unsafe.Pointer), gpuPos.Data().(unsafe.Pointer), seqLen, numHeads, headDim, theta, gpu)
	profile.End("RoPE/kernel")
	if err != nil {
		return err
	}

	// Copy the rotated rows back to the host.
	if err := gpuX.CopyToHost(xCPU.DataFloat32()); err != nil {
		return err
	}
	return nil
}
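
// A minimal usage sketch for RoPE (illustrative only; tensor setup elided).
// x is assumed to be [len(positions), numHeads*headDim], and theta = 10000
// is the conventional base used by many RoPE models:
//
//	positions := []int{0, 1, 2, 3} // absolute position of each row in x
//	if err := RoPE(ctx, x, positions, headDim, 10000.0); err != nil {
//		return err
//	}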

// ropeCPU rotates each half-pair (x[j], x[j+halfDim]) within every head by
// the angle pos * theta^(-2j/headDim).
func ropeCPU(x tensor.Tensor, positions []int, headDim int, theta float32) error {
	xCPU := x.(*cpu.Tensor)
	data := xCPU.DataFloat32()
	shape := x.Shape()
	seqLen := shape[0]
	totalDim := shape[1]
	halfDim := headDim / 2

	// Precompute the per-pair inverse frequencies once for all rows.
	invFreqs := make([]float64, halfDim)
	for j := 0; j < halfDim; j++ {
		invFreqs[j] = 1.0 / math.Pow(float64(theta), float64(2*j)/float64(headDim))
	}
	for seq := 0; seq < seqLen; seq++ {
		pos := positions[seq]
		rowStart := seq * totalDim
		for headStart := 0; headStart < totalDim; headStart += headDim {
			for j := 0; j < halfDim; j++ {
				freq := float64(pos) * invFreqs[j]
				sin, cos := math.Sincos(freq)
				idx0 := rowStart + headStart + j
				idx1 := rowStart + headStart + j + halfDim
				v0 := data[idx0]
				v1 := data[idx1]
				data[idx0] = v0*float32(cos) - v1*float32(sin)
				data[idx1] = v1*float32(cos) + v0*float32(sin)
			}
		}
	}
	return nil
}
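
// Worked example of ropeCPU's frequency table: with headDim = 4 and
// theta = 10000, halfDim = 2 and
//
//	invFreqs[0] = 10000^(-0/4) = 1.0
//	invFreqs[1] = 10000^(-2/4) = 0.01
//
// so at pos = 2 the first pair rotates by 2.0 rad and the second by 0.02 rad.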

// SwiGLU computes SiLU(gate) * up element-wise and stores the result in out.
// All three tensors must have the same number of elements; this is not
// validated here.
func SwiGLU(ctx *Context, gate, up, out tensor.Tensor) error {
	useGPU := ctx != nil && ctx.IsGPU() && device.CUDAAvailable()
	if !useGPU {
		profile.Start("SwiGLU/CPU")
		err := swigluCPU(gate, up, out)
		profile.End("SwiGLU/CPU")
		return err
	}
	profile.Start("SwiGLU/GPU")
	defer profile.End("SwiGLU/GPU")

	gCPU := gate.(*cpu.Tensor)
	uCPU := up.(*cpu.Tensor)
	oCPU := out.(*cpu.Tensor)
	gpu := ctx.Placement().GPU
	n := gate.Shape().NumElements()

	// Upload gate and up to the GPU.
	gpuGate, err := cuda.NewTensor(gate.Shape(), tensor.Float32, gpu)
	if err != nil {
		return err
	}
	gpuUp, err := cuda.NewTensor(up.Shape(), tensor.Float32, gpu)
	if err != nil {
		return err
	}
	if err := gpuGate.CopyFrom(gCPU.DataFloat32()); err != nil {
		return err
	}
	if err := gpuUp.CopyFrom(uCPU.DataFloat32()); err != nil {
		return err
	}

	// Apply SiLU to gate in-place.
	profile.Start("SwiGLU/silu_kernel")
	if err := cuda.SiLU(gpuGate.Data().(unsafe.Pointer), n, gpu); err != nil {
		profile.End("SwiGLU/silu_kernel")
		return err
	}
	profile.End("SwiGLU/silu_kernel")

	// gate *= up, element-wise.
	profile.Start("SwiGLU/mul_kernel")
	if err := cuda.MulInplace(gpuGate.Data().(unsafe.Pointer), gpuUp.Data().(unsafe.Pointer), n, gpu); err != nil {
		profile.End("SwiGLU/mul_kernel")
		return err
	}
	profile.End("SwiGLU/mul_kernel")

	// Copy the result into the output tensor.
	if err := gpuGate.CopyToHost(oCPU.DataFloat32()); err != nil {
		return err
	}
	return nil
}
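
// A minimal usage sketch for SwiGLU inside a feed-forward block (illustrative
// only; gate and up stand for the two FFN projection outputs, tensor setup
// elided):
//
//	// gate, up, out: *cpu.Tensor values with identical shapes
//	if err := SwiGLU(ctx, gate, up, out); err != nil {
//		return err
//	}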

// swigluCPU applies silu(g) = g / (1 + e^(-g)) to each gate element and
// multiplies by the corresponding up element on the host.
func swigluCPU(gate, up, out tensor.Tensor) error {
	gCPU := gate.(*cpu.Tensor)
	uCPU := up.(*cpu.Tensor)
	oCPU := out.(*cpu.Tensor)
	gData := gCPU.DataFloat32()
	uData := uCPU.DataFloat32()
	oData := oCPU.DataFloat32()
	for i := range gData {
		gv := gData[i]
		silu := gv / (1.0 + float32(math.Exp(float64(-gv))))
		oData[i] = silu * uData[i]
	}
	return nil
}
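
// Worked example of the activation in swigluCPU: silu(x) = x / (1 + e^(-x)),
// so silu(1) ≈ 0.7311 and silu(-1) ≈ -0.2689; each output element is
// silu(gate[i]) * up[i].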

// Softmax applies softmax along the last dimension. Only a CPU path is
// implemented; the context is currently unused.
func Softmax(ctx *Context, x tensor.Tensor) error {
	xCPU := x.(*cpu.Tensor)
	return softmaxCPU(xCPU)
}

func softmaxCPU(x *cpu.Tensor) error {
	data := x.DataFloat32()
	shape := x.Shape()
	if len(shape) != 2 {
		return fmt.Errorf("softmax: expected 2D tensor")
	}
	rows := shape[0]
	cols := shape[1]
	for i := 0; i < rows; i++ {
		row := data[i*cols : (i+1)*cols]
		// Subtract the row max before exponentiating for numerical stability.
		maxVal := row[0]
		for _, v := range row[1:] {
			if v > maxVal {
				maxVal = v
			}
		}
		sum := float32(0)
		for j := range row {
			row[j] = float32(math.Exp(float64(row[j] - maxVal)))
			sum += row[j]
		}
		for j := range row {
			row[j] /= sum
		}
	}
	return nil
}
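
// Worked example of softmaxCPU's row transform: for row = [1, 2, 3], the max
// 3 is subtracted first (for numerical stability), giving exp([-2, -1, 0]) ≈
// [0.1353, 0.3679, 1.0], sum ≈ 1.5032, so the row becomes
// ≈ [0.0900, 0.2447, 0.6652].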

// Add performs element-wise addition: dst += src. src must hold at least as
// many elements as dst.
func Add(dst, src tensor.Tensor) error {
	dCPU := dst.(*cpu.Tensor)
	sCPU := src.(*cpu.Tensor)
	dData := dCPU.DataFloat32()
	sData := sCPU.DataFloat32()
	for i := range dData {
		dData[i] += sData[i]
	}
	return nil
}

// CopyData copies tensor data: dst = src.
func CopyData(dst, src tensor.Tensor) error {
	dCPU := dst.(*cpu.Tensor)
	sCPU := src.(*cpu.Tensor)
	copy(dCPU.DataFloat32(), sCPU.DataFloat32())
	return nil
}