hybrid_ops_nocuda.go

//go:build !cuda

// Package compute provides CPU-only implementations of hybrid operations.
// When CUDA is not available, all operations fall back to CPU.
package compute

import (
	"math"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/tensor"
)

// HybridLinear performs matrix multiplication (CPU-only path).
func HybridLinear(ctx *Context, input *Activation, weight tensor.Tensor, output *Activation) error {
	inCPU, err := input.AsCPU()
	if err != nil {
		return err
	}
	outCPU, err := output.AsCPU()
	if err != nil {
		return err
	}
	// In the !cuda build the weight is expected to already live on the CPU.
	wCPU := weight.(*cpu.Tensor)
	return matmul.Linear(inCPU, wCPU, outCPU)
}

// HybridRMSNorm applies RMS normalization in-place (CPU-only).
func HybridRMSNorm(ctx *Context, x *Activation, w tensor.Tensor, eps float32) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	wCPU := w.(*cpu.Tensor)
	xData := xCPU.DataFloat32()
	wData := wCPU.DataFloat32()
	dim := wCPU.Shape().NumElements()
	numRows := xCPU.Shape().NumElements() / dim
	for i := 0; i < numRows; i++ {
		row := xData[i*dim : (i+1)*dim]
		// Mean of squares, then scale each element by 1/RMS and the learned weight.
		ss := cpu.DotFloat32(row, row) / float32(dim)
		invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
		for j := 0; j < dim; j++ {
			row[j] = row[j] * invRMS * wData[j]
		}
	}
	return nil
}
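
// For reference, the loop above computes, per row, y[j] = x[j] * w[j] / sqrt(mean(x^2) + eps).
// The standalone sketch below mirrors that math on plain slices; the helper name
// is illustrative only and is not used elsewhere in this package.
func rmsNormRowSketch(row, w []float32, eps float32) []float32 {
	var ss float32
	for _, v := range row {
		ss += v * v
	}
	invRMS := 1 / float32(math.Sqrt(float64(ss/float32(len(row))+eps)))
	out := make([]float32, len(row))
	for j, v := range row {
		out[j] = v * invRMS * w[j]
	}
	return out
}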

// HybridRoPE applies rotary positional embeddings in-place (CPU-only).
func HybridRoPE(ctx *Context, x *Activation, positions []int, headDim int, theta float32) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	data := xCPU.DataFloat32()
	shape := x.Shape()
	seqLen := shape[0]
	totalDim := shape[1]
	halfDim := headDim / 2
	// Precompute the per-pair inverse frequencies: 1 / theta^(2j/headDim).
	invFreqs := make([]float64, halfDim)
	for j := 0; j < halfDim; j++ {
		invFreqs[j] = 1.0 / math.Pow(float64(theta), float64(2*j)/float64(headDim))
	}
	for seq := 0; seq < seqLen; seq++ {
		pos := positions[seq]
		rowStart := seq * totalDim
		for headStart := 0; headStart < totalDim; headStart += headDim {
			for j := 0; j < halfDim; j++ {
				freq := float64(pos) * invFreqs[j]
				sin, cos := math.Sincos(freq)
				// Rotate the pair (x[j], x[j+halfDim]) within this head.
				idx0 := rowStart + headStart + j
				idx1 := rowStart + headStart + j + halfDim
				v0 := data[idx0]
				v1 := data[idx1]
				data[idx0] = v0*float32(cos) - v1*float32(sin)
				data[idx1] = v1*float32(cos) + v0*float32(sin)
			}
		}
	}
	return nil
}
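
// RoPE rotates each (x[j], x[j+halfDim]) pair within a head by the angle
// pos / theta^(2j/headDim), which is exactly the freq computed above. The sketch
// below applies that rotation to a single pair; the helper name is illustrative
// only and is not used elsewhere in this package.
func ropePairSketch(v0, v1 float32, pos, j, headDim int, theta float32) (float32, float32) {
	angle := float64(pos) / math.Pow(float64(theta), float64(2*j)/float64(headDim))
	sin, cos := math.Sincos(angle)
	return v0*float32(cos) - v1*float32(sin), v1*float32(cos) + v0*float32(sin)
}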

// HybridSoftmax applies softmax along the last dimension in-place (CPU-only).
func HybridSoftmax(ctx *Context, x *Activation) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	data := xCPU.DataFloat32()
	shape := x.Shape()
	rows, cols := shape[0], shape[1]
	for i := 0; i < rows; i++ {
		row := data[i*cols : (i+1)*cols]
		// Subtract the row max before exponentiating for numerical stability.
		maxVal := row[0]
		for _, v := range row[1:] {
			if v > maxVal {
				maxVal = v
			}
		}
		sum := float32(0)
		for j := range row {
			row[j] = float32(math.Exp(float64(row[j] - maxVal)))
			sum += row[j]
		}
		for j := range row {
			row[j] /= sum
		}
	}
	return nil
}
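
// Worked example (illustrative): for the row [1, 2, 3], subtracting the max gives
// [-2, -1, 0], exponentiating gives roughly [0.135, 0.368, 1.0], and dividing by
// the sum (about 1.503) yields probabilities of roughly [0.090, 0.245, 0.665].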

// HybridSiLU applies SiLU activation in-place (CPU-only).
func HybridSiLU(ctx *Context, x *Activation) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	return nn.SiLU(xCPU)
}

// HybridSwiGLU computes out = SiLU(gate) * up, element-wise (CPU-only).
func HybridSwiGLU(ctx *Context, gate, up, out *Activation) error {
	if err := HybridCopy(ctx, out, gate); err != nil {
		return err
	}
	if err := HybridSiLU(ctx, out); err != nil {
		return err
	}
	return HybridMul(ctx, out, up)
}
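
// HybridSwiGLU above composes copy, SiLU, and multiply. Assuming nn.SiLU implements
// the standard SiLU(x) = x * sigmoid(x), the net effect is out[i] = SiLU(gate[i]) * up[i].
// The minimal sketch below shows that element-wise math; the helper name is
// illustrative only and is not used elsewhere in this package.
func swigluSketch(gate, up []float32) []float32 {
	out := make([]float32, len(gate))
	for i, g := range gate {
		silu := g / (1 + float32(math.Exp(float64(-g)))) // equals g * sigmoid(g)
		out[i] = silu * up[i]
	}
	return out
}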

// HybridMul performs element-wise multiplication: a = a * b (CPU-only).
func HybridMul(ctx *Context, a, b *Activation) error {
	aCPU, err := a.AsCPU()
	if err != nil {
		return err
	}
	bCPU, err := b.AsCPU()
	if err != nil {
		return err
	}
	aData := aCPU.DataFloat32()
	bData := bCPU.DataFloat32()
	for i := range aData {
		aData[i] *= bData[i]
	}
	return nil
}

// HybridAdd performs element-wise addition: a = a + b (CPU-only).
func HybridAdd(ctx *Context, a, b *Activation) error {
	aCPU, err := a.AsCPU()
	if err != nil {
		return err
	}
	bCPU, err := b.AsCPU()
	if err != nil {
		return err
	}
	aData := aCPU.DataFloat32()
	bData := bCPU.DataFloat32()
	for i := range aData {
		aData[i] += bData[i]
	}
	return nil
}

// HybridAttention computes full causal attention (CPU-only).
func HybridAttention(ctx *Context, Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
	qCPU, err := Q.AsCPU()
	if err != nil {
		return err
	}
	kCPU, err := K.AsCPU()
	if err != nil {
		return err
	}
	vCPU, err := V.AsCPU()
	if err != nil {
		return err
	}
	outCPU, err := out.AsCPU()
	if err != nil {
		return err
	}
	qData := qCPU.DataFloat32()
	kData := kCPU.DataFloat32()
	vData := vCPU.DataFloat32()
	outData := outCPU.DataFloat32()
	seqLen := Q.Shape()[0]
	kvLen := K.Shape()[0]
	// Grouped-query attention: several query heads share one KV head.
	headsPerKV := numHeads / numKVHeads
	for h := 0; h < numHeads; h++ {
		kvHead := h / headsPerKV
		for q := 0; q < seqLen; q++ {
			qOffset := q*numHeads*headDim + h*headDim
			qVec := qData[qOffset : qOffset+headDim]
			// Scaled dot-product scores with a causal mask: query q may only
			// attend to key positions k <= startPos+q.
			scores := make([]float32, kvLen)
			maxScore := float32(-1e9)
			for k := 0; k < kvLen; k++ {
				if k > startPos+q {
					scores[k] = float32(-1e9)
					continue
				}
				kOffset := k*numKVHeads*headDim + kvHead*headDim
				kVec := kData[kOffset : kOffset+headDim]
				dot := float32(0)
				for d := 0; d < headDim; d++ {
					dot += qVec[d] * kVec[d]
				}
				scores[k] = dot * scale
				if scores[k] > maxScore {
					maxScore = scores[k]
				}
			}
			// Softmax over the scores (max-subtracted for stability).
			sum := float32(0)
			for k := range scores {
				scores[k] = float32(math.Exp(float64(scores[k] - maxScore)))
				sum += scores[k]
			}
			for k := range scores {
				scores[k] /= sum
			}
			// Weighted sum of the value vectors.
			outOffset := q*numHeads*headDim + h*headDim
			for d := 0; d < headDim; d++ {
				acc := float32(0)
				for k := 0; k < kvLen; k++ {
					vOffset := k*numKVHeads*headDim + kvHead*headDim
					acc += scores[k] * vData[vOffset+d]
				}
				outData[outOffset+d] = acc
			}
		}
	}
	return nil
}
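
// HybridAttention assumes Q laid out as [seqLen, numHeads*headDim] and K/V as
// [kvLen, numKVHeads*headDim], row-major, with query head h reading KV head
// h/(numHeads/numKVHeads). The sketch below shows the same mask/softmax/weighted-sum
// steps for a single query vector against one head's keys and values, without the
// flattened indexing; the helper name is illustrative only and is not used elsewhere.
func attentionRowSketch(qVec []float32, keys, values [][]float32, scale float32, lastVisible int) []float32 {
	scores := make([]float32, len(keys))
	maxScore := float32(-1e9)
	for k := range keys {
		if k > lastVisible {
			scores[k] = float32(-1e9) // causal mask: future positions contribute ~0 weight
			continue
		}
		var dot float32
		for d := range qVec {
			dot += qVec[d] * keys[k][d]
		}
		scores[k] = dot * scale
		if scores[k] > maxScore {
			maxScore = scores[k]
		}
	}
	var sum float32
	for k := range scores {
		scores[k] = float32(math.Exp(float64(scores[k] - maxScore)))
		sum += scores[k]
	}
	out := make([]float32, len(qVec))
	for k := range keys {
		w := scores[k] / sum
		for d := range out {
			out[d] += w * values[k][d]
		}
	}
	return out
}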

// HybridCopy copies src to dst (CPU-only).
func HybridCopy(ctx *Context, dst, src *Activation) error {
	dstCPU, err := dst.AsCPU()
	if err != nil {
		return err
	}
	srcCPU, err := src.AsCPU()
	if err != nil {
		return err
	}
	copy(dstCPU.DataFloat32(), srcCPU.DataFloat32())
	return nil
}

// EnsureOnDevice moves activation to target device if needed (CPU-only stub).
func EnsureOnDevice(a *Activation, target tensor.DevicePlacement) error {
	_, err := a.EnsureOn(target)
	return err
}