hybrid_ops.go

//go:build cuda

// Package compute provides hybrid CPU/GPU neural network operations.
// Operations automatically dispatch to the correct backend based on activation placement.
package compute

import (
	"fmt"
	"math"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)

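// A typical call sequence, as a rough sketch (the variables shown here are
// illustrative, not part of this package's API):
//
//	// x and out already exist as *Activation; ctx carries the layer's placement.
//	if err := HybridRMSNorm(ctx, x, normWeight, 1e-6); err != nil { /* handle */ }
//	if err := HybridLinear(ctx, x, qWeight, out); err != nil { /* handle */ }
//
// Each Hybrid* op checks the activation's placement and the Context, runs the
// GPU kernel when both are on the GPU, and otherwise falls back to the CPU path.
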
// HybridLinear performs matrix multiplication on either CPU or GPU.
// Automatically uses weight cache for GPU weights.
func HybridLinear(ctx *Context, input *Activation, weight tensor.Tensor, output *Activation) error {
	if input.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridLinear/GPU")
		err := hybridLinearGPU(ctx, input, weight, output)
		profile.End("HybridLinear/GPU")
		return err
	}
	profile.Start("HybridLinear/CPU")
	err := hybridLinearCPU(input, weight, output)
	profile.End("HybridLinear/CPU")
	return err
}

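// hybridLinearCPU runs the matmul on the CPU backend. The weight must be a
// *cpu.Tensor; any other tensor type panics on the assertion below.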
func hybridLinearCPU(input *Activation, weight tensor.Tensor, output *Activation) error {
	inCPU, err := input.AsCPU()
	if err != nil {
		return err
	}
	outCPU, err := output.AsCPU()
	if err != nil {
		return err
	}
	wCPU := weight.(*cpu.Tensor)
	return matmul.Linear(inCPU, wCPU, outCPU)
}

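// hybridLinearGPU runs the matmul on the GPU backend. Shapes follow the
// convention used throughout this file: input is [M, K], weight is [N, K],
// and the result written to output is [M, N]. Weights are uploaded once and
// then reused via the per-device weight cache; float32 weights are cached in
// FP16 form for the dense GEMM path.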
func hybridLinearGPU(ctx *Context, input *Activation, weight tensor.Tensor, output *Activation) error {
	gpu := ctx.Placement().GPU
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	// Get GPU input.
	gpuIn, err := input.AsCUDA(gpu)
	if err != nil {
		return err
	}
	// Look up (or upload) the cached weight.
	cache := GetWeightCache(gpu)
	var weightKey string
	if wCPU, ok := weight.(*cpu.Tensor); ok {
		weightKey = fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, wCPU)
	} else {
		weightKey = fmt.Sprintf("layer%d_w_%T_%p", ctx.LayerIdx, weight, weight)
	}
	upload := cache.Upload
	if weight.DType() == tensor.Float32 {
		// Use a separate key so float32 weights needed by other ops can coexist with FP16 GEMM weights.
		if wCPU, ok := weight.(*cpu.Tensor); ok {
			weightKey = fmt.Sprintf("layer%d_w_f16_%p", ctx.LayerIdx, wCPU)
		} else {
			weightKey = fmt.Sprintf("layer%d_w_f16_%T_%p", ctx.LayerIdx, weight, weight)
		}
		upload = cache.UploadF16
	}
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		cpuW, ok := weight.(*cpu.Tensor)
		if !ok {
			return fmt.Errorf("hybrid linear: cannot upload weight of type %T", weight)
		}
		gpuWeight, err = upload(weightKey, cpuW)
		if err != nil {
			return fmt.Errorf("hybrid linear: cache weight: %w", err)
		}
	}
	// Reuse preallocated output buffer when possible (e.g., scratch views).
	var gpuOut *cuda.Tensor
	if output != nil && output.IsGPU() {
		if outT, err := output.AsCUDA(gpu); err == nil {
			if outT.DType() == tensor.Float32 {
				shape := outT.Shape()
				if len(shape) == 2 && shape[0] == M && shape[1] == N {
					gpuOut = outT
				}
			}
		}
	}
	if gpuOut == nil {
		t, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
		if err != nil {
			return err
		}
		output.ReplaceWith(t)
		gpuOut = t
	}
	// Execute based on weight dtype and input dtype.
	aPtr := gpuIn.Data().(unsafe.Pointer)
	cPtr := gpuOut.Data().(unsafe.Pointer)
	inputIsF16 := gpuIn.DType() == tensor.Float16
	// Prefer FP16 input for quant matmuls (memory bandwidth win).
	// If activations are still FP32, cast to FP16 on GPU and use the FP16 kernels.
	if !inputIsF16 {
		switch weight.DType() {
		case tensor.Q8_K, tensor.Q5_K, tensor.Q4_K, tensor.Q2_K, tensor.Q3_K, tensor.Q6_K:
			f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
			if err != nil {
				return err
			}
			defer f16In.Free()
			if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil {
				return err
			}
			aPtr = f16In.Data().(unsafe.Pointer)
			inputIsF16 = true
		}
	}
	switch weight.DType() {
	case tensor.Q8_K:
		if inputIsF16 {
			if err := cuda.MatMulF16Q8K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		} else {
			if err := cuda.MatMulQ8K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		}
	case tensor.Q5_K:
		if inputIsF16 {
			if err := cuda.MatMulF16Q5K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		} else {
			if err := cuda.MatMulQ5K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		}
	case tensor.Q4_K:
		if inputIsF16 {
			if err := cuda.MatMulF16Q4K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		} else {
			if err := cuda.MatMulQ4K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		}
	case tensor.Q2_K:
		if inputIsF16 {
			if err := cuda.MatMulF16Q2K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		} else {
			if err := cuda.MatMulQ2K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		}
	case tensor.Q3_K:
		if inputIsF16 {
			if err := cuda.MatMulF16Q3K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		} else {
			if err := cuda.MatMulQ3K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		}
	case tensor.Q6_K:
		if inputIsF16 {
			if err := cuda.MatMulF16Q6K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		} else {
			if err := cuda.MatMulQ6K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
				return err
			}
		}
	default:
		// Dense GEMM path (weights cached as FP16 on GPU).
		if !inputIsF16 {
			f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
			if err != nil {
				return err
			}
			defer f16In.Free()
			if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil {
				return err
			}
			aPtr = f16In.Data().(unsafe.Pointer)
			inputIsF16 = true
		}
		if err := cuda.MatMulF16(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
			return err
		}
	}
	return nil
}

// HybridRMSNorm applies RMS normalization in-place.
func HybridRMSNorm(ctx *Context, x *Activation, w tensor.Tensor, eps float32) error {
	if x.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridRMSNorm/GPU")
		err := hybridRMSNormGPU(ctx, x, w, eps)
		profile.End("HybridRMSNorm/GPU")
		return err
	}
	profile.Start("HybridRMSNorm/CPU")
	err := hybridRMSNormCPU(x, w, eps)
	profile.End("HybridRMSNorm/CPU")
	return err
}

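// hybridRMSNormCPU is the reference CPU implementation. For each row x it
// computes
//
//	y[j] = x[j] * w[j] / sqrt(mean(x^2) + eps)
//
// which matches the loop below: ss is the mean of squares and invRMS is the
// reciprocal square root.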
func hybridRMSNormCPU(x *Activation, w tensor.Tensor, eps float32) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	wCPU := w.(*cpu.Tensor)
	xData := xCPU.DataFloat32()
	wData := wCPU.DataFloat32()
	dim := wCPU.Shape().NumElements()
	numRows := xCPU.Shape().NumElements() / dim
	for i := 0; i < numRows; i++ {
		row := xData[i*dim : (i+1)*dim]
		ss := cpu.DotFloat32(row, row) / float32(dim)
		invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
		for j := 0; j < dim; j++ {
			row[j] = row[j] * invRMS * wData[j]
		}
	}
	return nil
}

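// hybridRMSNormGPU runs RMSNorm on the GPU. When the weight is smaller than
// the activation's last dimension (per-head q/k norms), the rows are
// reinterpreted as [seqLen*numHeads, headDim] so the same kernel can be used.
// If the dimensions do not divide evenly, it falls back to the CPU path and
// moves the activation back to the GPU afterwards.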
func hybridRMSNormGPU(ctx *Context, x *Activation, w tensor.Tensor, eps float32) error {
	gpu := ctx.Placement().GPU
	shape := x.Shape()
	seqLen, dim := shape[0], shape[1]
	wShape := w.Shape()
	wDim := wShape.NumElements()
	// For per-head normalization (qNorm/kNorm) the weight dimension is smaller than the
	// activation dimension. If it divides evenly, we can still run on GPU by reshaping.
	if wDim != dim {
		if dim%wDim == 0 {
			// e.g. dim=3584, wDim=128 (28 heads): treat the input as [seqLen*numHeads, headDim]
			// and reuse the same kernel with modified dimensions.
			numHeads := dim / wDim
			effectiveSeqLen := seqLen * numHeads
			gpuX, err := x.AsCUDA(gpu)
			if err != nil {
				return err
			}
			// Get cached weight.
			cache := GetWeightCache(gpu)
			var wKey string
			if wCPU, ok := w.(*cpu.Tensor); ok {
				wKey = fmt.Sprintf("norm_%p", wCPU)
			} else {
				wKey = fmt.Sprintf("norm_%T_%p", w, w)
			}
			gpuW, ok := cache.Get(wKey)
			if !ok {
				wCPU, ok := w.(*cpu.Tensor)
				if !ok {
					return fmt.Errorf("rmsnorm: cannot upload weight of type %T", w)
				}
				gpuW, err = cache.Upload(wKey, wCPU)
				if err != nil {
					return fmt.Errorf("rmsnorm: cache upload failed: %w", err)
				}
			}
			if gpuW == nil {
				return fmt.Errorf("rmsnorm: got nil weight pointer from cache")
			}
			return cuda.RMSNorm(gpuX.Data().(unsafe.Pointer), gpuW, effectiveSeqLen, wDim, eps, gpu)
		}
		// The reshape is not possible: fall back to the CPU implementation,
		// then move the activation back to the GPU if it lived there before.
		wasGPU := x.IsGPU()
		if err := hybridRMSNormCPU(x, w, eps); err != nil {
			return err
		}
		if wasGPU {
			if _, err := x.EnsureOn(ctx.Placement()); err != nil {
				return fmt.Errorf("restore to GPU after per-head norm: %w", err)
			}
		}
		return nil
	}
	gpuX, err := x.AsCUDA(gpu)
	if err != nil {
		return err
	}
	// Get cached weight.
	cache := GetWeightCache(gpu)
	var wKey string
	if wCPU, ok := w.(*cpu.Tensor); ok {
		wKey = fmt.Sprintf("norm_%p", wCPU)
	} else {
		wKey = fmt.Sprintf("norm_%T_%p", w, w)
	}
	gpuW, ok := cache.Get(wKey)
	if !ok {
		wCPU, ok := w.(*cpu.Tensor)
		if !ok {
			return fmt.Errorf("rmsnorm: cannot upload weight of type %T", w)
		}
		gpuW, err = cache.Upload(wKey, wCPU)
		if err != nil {
			return fmt.Errorf("rmsnorm: cache upload failed: %w", err)
		}
	}
	if gpuW == nil {
		return fmt.Errorf("rmsnorm: got nil weight pointer from cache")
	}
	// Standard case: weight dimension matches activation dimension.
	return cuda.RMSNorm(gpuX.Data().(unsafe.Pointer), gpuW, seqLen, dim, eps, gpu)
}

// HybridRoPE applies rotary positional embeddings in-place.
func HybridRoPE(ctx *Context, x *Activation, positions []int, headDim int, theta float32) error {
	if x.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridRoPE/GPU")
		err := hybridRoPEGPU(ctx, x, positions, headDim, theta)
		profile.End("HybridRoPE/GPU")
		return err
	}
	profile.Start("HybridRoPE/CPU")
	err := hybridRoPECPU(x, positions, headDim, theta)
	profile.End("HybridRoPE/CPU")
	return err
}

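// hybridRoPECPU is the reference CPU implementation. Within each head, element
// pairs (j, j+headDim/2) are rotated by the position-dependent angle
//
//	angle = pos / theta^(2j/headDim)
//	x0' = x0*cos(angle) - x1*sin(angle)
//	x1' = x1*cos(angle) + x0*sin(angle)
//
// which is what the inner loop below computes via math.Sincos.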
func hybridRoPECPU(x *Activation, positions []int, headDim int, theta float32) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	data := xCPU.DataFloat32()
	shape := x.Shape()
	seqLen := shape[0]
	totalDim := shape[1]
	halfDim := headDim / 2
	invFreqs := make([]float64, halfDim)
	for j := 0; j < halfDim; j++ {
		invFreqs[j] = 1.0 / math.Pow(float64(theta), float64(2*j)/float64(headDim))
	}
	for seq := 0; seq < seqLen; seq++ {
		pos := positions[seq]
		rowStart := seq * totalDim
		for headStart := 0; headStart < totalDim; headStart += headDim {
			for j := 0; j < halfDim; j++ {
				freq := float64(pos) * invFreqs[j]
				sin, cos := math.Sincos(freq)
				idx0 := rowStart + headStart + j
				idx1 := rowStart + headStart + j + halfDim
				v0 := data[idx0]
				v1 := data[idx1]
				data[idx0] = v0*float32(cos) - v1*float32(sin)
				data[idx1] = v1*float32(cos) + v0*float32(sin)
			}
		}
	}
	return nil
}

func hybridRoPEGPU(ctx *Context, x *Activation, positions []int, headDim int, theta float32) error {
	gpu := ctx.Placement().GPU
	shape := x.Shape()
	seqLen := shape[0]
	totalDim := shape[1]
	numHeads := totalDim / headDim
	gpuX, err := x.AsCUDA(gpu)
	if err != nil {
		return err
	}
	// Optimization: for a single-token update (decode phase), pass the position to the
	// kernel as a scalar and avoid all memory allocation/copy overhead.
	if len(positions) == 1 {
		pos := positions[0]
		return cuda.RoPESingle(gpuX.Data().(unsafe.Pointer), pos, numHeads, headDim, theta, gpu)
	}
	// Upload positions as int32 (the CUDA kernel expects int*).
	posData := make([]int32, len(positions))
	for i, p := range positions {
		posData[i] = int32(p)
	}
	var gpuPosPtr unsafe.Pointer
	// Try scratch space first to avoid malloc/free overhead.
	if ctx != nil && ctx.Scratch != nil {
		gpuPosPtr, err = ctx.Scratch.GetInt32Slice(len(positions))
	}
	shouldFree := false
	if gpuPosPtr == nil || err != nil {
		// No scratch available (or it failed): allocate and copy in one call.
		gpuPosPtr, err = cuda.AllocAndCopyInt32(posData, gpu)
		if err != nil {
			return fmt.Errorf("RoPE: upload positions: %w", err)
		}
		shouldFree = true
	} else {
		// Scratch only allocates, so copy the positions manually.
		err = cuda.MemcpyH2D(gpuPosPtr, unsafe.Pointer(&posData[0]), uintptr(len(posData)*4), gpu)
		if err != nil {
			return fmt.Errorf("RoPE: memcpy positions: %w", err)
		}
	}
	if shouldFree {
		defer cuda.Free(gpuPosPtr)
	}
	if err := cuda.RoPE(gpuX.Data().(unsafe.Pointer), gpuPosPtr, seqLen, numHeads, headDim, theta, gpu); err != nil {
		return err
	}
	// No explicit synchronize: scratch memory persists until the end of the step, and
	// cudaFree is stream-ordered, so the kernel finishes reading before any free happens.
	return nil
}

// HybridSoftmax applies softmax along the last dimension in-place.
func HybridSoftmax(ctx *Context, x *Activation) error {
	if x.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridSoftmax/GPU")
		err := hybridSoftmaxGPU(ctx, x)
		profile.End("HybridSoftmax/GPU")
		return err
	}
	profile.Start("HybridSoftmax/CPU")
	err := hybridSoftmaxCPU(x)
	profile.End("HybridSoftmax/CPU")
	return err
}

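// hybridSoftmaxCPU is the reference CPU implementation. Each row is reduced with
// the numerically stable form
//
//	softmax(x)[j] = exp(x[j]-max(x)) / sum_k exp(x[k]-max(x))
//
// so large logits cannot overflow the exponential.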
func hybridSoftmaxCPU(x *Activation) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	data := xCPU.DataFloat32()
	shape := x.Shape()
	rows, cols := shape[0], shape[1]
	for i := 0; i < rows; i++ {
		row := data[i*cols : (i+1)*cols]
		maxVal := row[0]
		for _, v := range row[1:] {
			if v > maxVal {
				maxVal = v
			}
		}
		sum := float32(0)
		for j := range row {
			row[j] = float32(math.Exp(float64(row[j] - maxVal)))
			sum += row[j]
		}
		for j := range row {
			row[j] /= sum
		}
	}
	return nil
}

func hybridSoftmaxGPU(ctx *Context, x *Activation) error {
	gpu := ctx.Placement().GPU
	shape := x.Shape()
	rows, cols := shape[0], shape[1]
	gpuX, err := x.AsCUDA(gpu)
	if err != nil {
		return err
	}
	return cuda.Softmax(gpuX.Data().(unsafe.Pointer), rows, cols, gpu)
}

// HybridSiLU applies SiLU activation in-place: x = x * sigmoid(x).
func HybridSiLU(ctx *Context, x *Activation) error {
	if x.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridSiLU/GPU")
		err := hybridSiLUGPU(ctx, x)
		profile.End("HybridSiLU/GPU")
		return err
	}
	profile.Start("HybridSiLU/CPU")
	err := hybridSiLUCPU(x)
	profile.End("HybridSiLU/CPU")
	return err
}

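// HybridSwiGLU computes the SwiGLU feed-forward gate: out = SiLU(gate) * up,
// element-wise. It is built from the hybrid copy, SiLU, and multiply ops, so it
// runs on whichever backend the activations already live on.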
func HybridSwiGLU(ctx *Context, gate, up, out *Activation) error {
	if err := HybridCopy(ctx, out, gate); err != nil {
		return err
	}
	if err := HybridSiLU(ctx, out); err != nil {
		return err
	}
	return HybridMul(ctx, out, up)
}

func hybridSiLUCPU(x *Activation) error {
	xCPU, err := x.AsCPU()
	if err != nil {
		return err
	}
	return nn.SiLU(xCPU)
}

func hybridSiLUGPU(ctx *Context, x *Activation) error {
	gpu := ctx.Placement().GPU
	gpuX, err := x.AsCUDA(gpu)
	if err != nil {
		return err
	}
	return cuda.SiLU(gpuX.Data().(unsafe.Pointer), x.Shape().NumElements(), gpu)
}

// HybridMul performs element-wise multiplication: a = a * b.
func HybridMul(ctx *Context, a, b *Activation) error {
	if a.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridMul/GPU")
		err := hybridMulGPU(ctx, a, b)
		profile.End("HybridMul/GPU")
		return err
	}
	profile.Start("HybridMul/CPU")
	err := hybridMulCPU(a, b)
	profile.End("HybridMul/CPU")
	return err
}

func hybridMulCPU(a, b *Activation) error {
	aCPU, err := a.AsCPU()
	if err != nil {
		return err
	}
	bCPU, err := b.AsCPU()
	if err != nil {
		return err
	}
	aData := aCPU.DataFloat32()
	bData := bCPU.DataFloat32()
	for i := range aData {
		aData[i] *= bData[i]
	}
	return nil
}

func hybridMulGPU(ctx *Context, a, b *Activation) error {
	gpu := ctx.Placement().GPU
	gpuA, err := a.AsCUDA(gpu)
	if err != nil {
		return err
	}
	gpuB, err := b.AsCUDA(gpu)
	if err != nil {
		return err
	}
	return cuda.MulInplace(gpuA.Data().(unsafe.Pointer), gpuB.Data().(unsafe.Pointer), a.Shape().NumElements(), gpu)
}

// HybridAdd performs element-wise addition: a = a + b.
func HybridAdd(ctx *Context, a, b *Activation) error {
	if a.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridAdd/GPU")
		err := hybridAddGPU(ctx, a, b)
		profile.End("HybridAdd/GPU")
		return err
	}
	profile.Start("HybridAdd/CPU")
	err := hybridAddCPU(a, b)
	profile.End("HybridAdd/CPU")
	return err
}

func hybridAddCPU(a, b *Activation) error {
	aCPU, err := a.AsCPU()
	if err != nil {
		return err
	}
	bCPU, err := b.AsCPU()
	if err != nil {
		return err
	}
	aData := aCPU.DataFloat32()
	bData := bCPU.DataFloat32()
	for i := range aData {
		aData[i] += bData[i]
	}
	return nil
}

func hybridAddGPU(ctx *Context, a, b *Activation) error {
	gpu := ctx.Placement().GPU
	gpuA, err := a.AsCUDA(gpu)
	if err != nil {
		return err
	}
	gpuB, err := b.AsCUDA(gpu)
	if err != nil {
		return err
	}
	return cuda.AddInplace(gpuA.Data().(unsafe.Pointer), gpuB.Data().(unsafe.Pointer), a.Shape().NumElements(), gpu)
}

// HybridAttention computes full causal attention.
func HybridAttention(ctx *Context, Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
	if Q.IsGPU() && ctx != nil && ctx.IsGPU() {
		profile.Start("HybridAttention/GPU")
		err := hybridAttentionGPU(ctx, Q, K, V, out, numHeads, numKVHeads, headDim, scale, startPos)
		profile.End("HybridAttention/GPU")
		return err
	}
	profile.Start("HybridAttention/CPU")
	err := hybridAttentionCPU(Q, K, V, out, numHeads, numKVHeads, headDim, scale, startPos)
	profile.End("HybridAttention/CPU")
	return err
}

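// hybridAttentionCPU runs causal attention on the CPU backend. Q is expected as
// [seqLen, numHeads*headDim] and K/V as [kvLen, numKVHeads*headDim]; when
// numKVHeads < numHeads the key/value heads are shared across groups of query
// heads (grouped-query attention). startPos offsets the causal mask so cached
// decode/prefill steps attend to all earlier positions.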
func hybridAttentionCPU(Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
	qCPU, err := Q.AsCPU()
	if err != nil {
		return err
	}
	kCPU, err := K.AsCPU()
	if err != nil {
		return err
	}
	vCPU, err := V.AsCPU()
	if err != nil {
		return err
	}
	outCPU, err := out.AsCPU()
	if err != nil {
		return err
	}
	// Prefer the optimized CPU kernels from backend/cpu/nn: they avoid per-token
	// allocations and use SIMD softmax/Axpy. Those kernels apply the attention
	// scaling internally, so the explicit scale parameter is unused here.
	_ = scale
	// Use cached causal attention when startPos is provided (decode/prefill with cache).
	// When seqLen == kvLen and startPos == 0 this is standard causal attention.
	return nn.CausalAttentionCached(qCPU, kCPU, vCPU, outCPU, numHeads, numKVHeads, headDim, startPos)
}

func hybridAttentionGPU(ctx *Context, Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
	gpu := ctx.Placement().GPU
	gpuQ, err := Q.AsCUDA(gpu)
	if err != nil {
		return err
	}
	gpuK, err := K.AsCUDA(gpu)
	if err != nil {
		return err
	}
	gpuV, err := V.AsCUDA(gpu)
	if err != nil {
		return err
	}
	// Allocate output on GPU.
	gpuOut, err := cuda.NewTensor(out.Shape(), tensor.Float32, gpu)
	if err != nil {
		return err
	}
	seqLen := Q.Shape()[0]
	kvLen := K.Shape()[0]
	err = cuda.Attention(
		gpuQ.Data().(unsafe.Pointer),
		gpuK.Data().(unsafe.Pointer),
		gpuV.Data().(unsafe.Pointer),
		gpuOut.Data().(unsafe.Pointer),
		seqLen, kvLen, numHeads, numKVHeads, headDim,
		scale, startPos, gpu,
	)
	if err != nil {
		return err
	}
	out.ReplaceWith(gpuOut)
	return nil
}

// HybridCopy copies src to dst.
func HybridCopy(ctx *Context, dst, src *Activation) error {
	if dst.IsGPU() && src.IsGPU() && ctx != nil && ctx.IsGPU() {
		return hybridCopyGPU(ctx, dst, src)
	}
	return hybridCopyCPU(dst, src)
}

func hybridCopyCPU(dst, src *Activation) error {
	dstCPU, err := dst.AsCPU()
	if err != nil {
		return err
	}
	srcCPU, err := src.AsCPU()
	if err != nil {
		return err
	}
	copy(dstCPU.DataFloat32(), srcCPU.DataFloat32())
	return nil
}

func hybridCopyGPU(ctx *Context, dst, src *Activation) error {
	gpu := ctx.Placement().GPU
	gpuDst, err := dst.AsCUDA(gpu)
	if err != nil {
		return err
	}
	gpuSrc, err := src.AsCUDA(gpu)
	if err != nil {
		return err
	}
	return cuda.Copy(gpuDst.Data().(unsafe.Pointer), gpuSrc.Data().(unsafe.Pointer), dst.Shape().NumElements(), gpu)
}

// EnsureOnDevice moves the activation to the target device if needed.
// This is the key function for hybrid execution: it only transfers data when
// crossing a device boundary.
func EnsureOnDevice(a *Activation, target tensor.DevicePlacement) error {
	// The transferred flag is currently unused; it could drive debug logging later.
	_, err := a.EnsureOn(target)
	return err
}