package compute

import (
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/tensor"
)

// HybridKDA runs a Kimi Delta Attention (KDA) block and writes its output to attnOut.
//
// Current behavior:
// - If the current layer is placed on GPU (ctx.IsGPU() && hidden.IsGPU()), the CUDA path
//   (hybridKDAGPU) is used.
// - Otherwise the CPU path (hybridKDACPU) is used.
func HybridKDA(
	ctx *Context,
	hidden *Activation,
	qProj, kProj, vProj tensor.Tensor,
	qConv, kConv, vConv tensor.Tensor,
	fAProj, fBProj, bProj tensor.Tensor,
	aLog tensor.Tensor,
	dtBias tensor.Tensor,
	gAProj, gBProj tensor.Tensor,
	oNorm tensor.Tensor,
	oProj tensor.Tensor,
	convQState, convKState, convVState tensor.Tensor,
	recurrentState tensor.Tensor,
	seqLen, numHeads, headDim, shortConvKernel int,
	eps float32,
	attnOut *Activation,
) error {
	if ctx != nil && ctx.IsGPU() && hidden.IsGPU() {
		return hybridKDAGPU(
			ctx,
			hidden,
			qProj, kProj, vProj,
			qConv, kConv, vConv,
			fAProj, fBProj, bProj,
			aLog,
			dtBias,
			gAProj, gBProj,
			oNorm,
			oProj,
			convQState, convKState, convVState,
			recurrentState,
			seqLen, numHeads, headDim, shortConvKernel,
			eps,
			attnOut,
		)
	}
	return hybridKDACPU(
		hidden,
		qProj, kProj, vProj,
		qConv, kConv, vConv,
		fAProj, fBProj, bProj,
		aLog,
		dtBias,
		gAProj, gBProj,
		oNorm,
		oProj,
		convQState, convKState, convVState,
		recurrentState,
		seqLen, numHeads, headDim, shortConvKernel,
		eps,
		attnOut,
	)
}
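
// hybridKDAGPU is the CUDA path: the linear projections go through HybridLinear,
// while the conv, norm, gate, and recurrence steps run as CUDA kernels on the
// layer's GPU.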
func hybridKDAGPU(
	ctx *Context,
	hidden *Activation,
	qProj, kProj, vProj tensor.Tensor,
	qConv, kConv, vConv tensor.Tensor,
	fAProj, fBProj, bProj tensor.Tensor,
	aLog tensor.Tensor,
	dtBias tensor.Tensor,
	gAProj, gBProj tensor.Tensor,
	oNorm tensor.Tensor,
	oProj tensor.Tensor,
	convQState, convKState, convVState tensor.Tensor,
	recurrentState tensor.Tensor,
	seqLen, numHeads, headDim, shortConvKernel int,
	eps float32,
	attnOut *Activation,
) error {
	if ctx == nil || !ctx.IsGPU() {
		return fmt.Errorf("HybridKDA/GPU: missing GPU context")
	}
	if !device.CUDAAvailable() || !cuda.Available() {
		return fmt.Errorf("HybridKDA/GPU: CUDA not available")
	}
	gpu := ctx.Placement().GPU
	projSize := numHeads * headDim
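	// alloc obtains a float32 activation of the given shape, preferring the
	// per-context scratch arena and falling back to a fresh CUDA allocation.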
	alloc := func(shape tensor.Shape) (*Activation, error) {
		if ctx.Scratch != nil {
			if act, err := ctx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
				return act, nil
			}
		}
		return NewActivation(shape, tensor.DevicePlacement{Type: tensor.CUDA, GPU: gpu})
	}
	// Project to Q/K/V on GPU.
	qAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	kAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	vAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, qProj, qAct); err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, kProj, kAct); err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, vProj, vAct); err != nil {
		return err
	}
	qCUDA, _ := qAct.AsCUDA(gpu)
	kCUDA, _ := kAct.AsCUDA(gpu)
	vCUDA, _ := vAct.AsCUDA(gpu)
	qStateCUDA, ok := convQState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: convQState not cuda tensor")
	}
	kStateCUDA, ok := convKState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: convKState not cuda tensor")
	}
	vStateCUDA, ok := convVState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: convVState not cuda tensor")
	}
	cache := GetWeightCache(gpu)
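	// uploadF32 resolves a float32 device pointer for a weight: CUDA tensors are
	// validated (device, dtype) and used in place; CPU tensors are uploaded once
	// and cached per layer/label in the per-GPU weight cache.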
	uploadF32 := func(label string, w tensor.Tensor) (unsafe.Pointer, error) {
		if w == nil {
			return nil, fmt.Errorf("HybridKDA/GPU: missing %s", label)
		}
		if wt, ok := w.(*cuda.Tensor); ok {
			if wt.GPU() != gpu {
				return nil, fmt.Errorf("HybridKDA/GPU: %s on gpu=%d (want %d)", label, wt.GPU(), gpu)
			}
			if wt.DType() != tensor.Float32 {
				return nil, fmt.Errorf("HybridKDA/GPU: %s dtype=%v (want Float32)", label, wt.DType())
			}
			return wt.Data().(unsafe.Pointer), nil
		}
		wCPU, ok := w.(*cpu.Tensor)
		if !ok {
			return nil, fmt.Errorf("HybridKDA/GPU: %s not cpu/cuda tensor (%T)", label, w)
		}
		key := fmt.Sprintf("kda_l%d_%s", ctx.LayerIdx, label)
		if ptr, ok := cache.Get(key); ok {
			return ptr, nil
		}
		ptr, err := cache.Upload(key, wCPU)
		if err != nil {
			return nil, fmt.Errorf("HybridKDA/GPU: upload %s: %w", label, err)
		}
		return ptr, nil
	}
	qWPtr, err := uploadF32("qconv", qConv)
	if err != nil {
		return err
	}
	kWPtr, err := uploadF32("kconv", kConv)
	if err != nil {
		return err
	}
	vWPtr, err := uploadF32("vconv", vConv)
	if err != nil {
		return err
	}
	// 1. Conv1d + SiLU
	if err := cuda.KDACausalShortConv1D(qCUDA.Data().(unsafe.Pointer), qStateCUDA.Data().(unsafe.Pointer), qWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
		return err
	}
	if err := cuda.KDACausalShortConv1D(kCUDA.Data().(unsafe.Pointer), kStateCUDA.Data().(unsafe.Pointer), kWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
		return err
	}
	if err := cuda.KDACausalShortConv1D(vCUDA.Data().(unsafe.Pointer), vStateCUDA.Data().(unsafe.Pointer), vWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
		return err
	}
	// 2. L2 Norm Q/K
	if err := cuda.L2NormHeads(qCUDA.Data().(unsafe.Pointer), kCUDA.Data().(unsafe.Pointer), seqLen, numHeads, headDim, 1e-6, gpu); err != nil {
		return err
	}
	// 3. Beta projection + sigmoid
	betaAct, err := alloc(tensor.Shape{seqLen, numHeads})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, bProj, betaAct); err != nil {
		return err
	}
	betaCUDA, err := betaAct.AsCUDA(gpu)
	if err != nil {
		return err
	}
	if err := cuda.Sigmoid(betaCUDA.Data().(unsafe.Pointer), seqLen*numHeads, gpu); err != nil {
		return err
	}
	// 4. Gate computation (f_a -> f_b -> gate)
	gAAct, err := alloc(tensor.Shape{seqLen, headDim})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, fAProj, gAAct); err != nil {
		return err
	}
	gBAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, gAAct, fBProj, gBAct); err != nil {
		return err
	}
	// Upload aLog and dtBias (cached on GPU).
	aLogCPU, ok := aLog.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: aLog not cpu tensor")
	}
	aLogFlat, err := nn.FlattenALog(aLogCPU, numHeads)
	if err != nil {
		return err
	}
	aLogView := cpu.NewTensor(tensor.Shape{numHeads}, aLogFlat[:numHeads])
	aLogPtr, err := uploadF32("alog", aLogView)
	if err != nil {
		return err
	}
	dtBiasCPU, ok := dtBias.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: dtBias not cpu tensor")
	}
	dtBiasPtr, err := uploadF32("dtbias", dtBiasCPU)
	if err != nil {
		return err
	}
	gBCUDA, _ := gBAct.AsCUDA(gpu)
	gOutAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	gOutCUDA, _ := gOutAct.AsCUDA(gpu)
	if err := cuda.KDAGate(gBCUDA.Data().(unsafe.Pointer), aLogPtr, dtBiasPtr, gOutCUDA.Data().(unsafe.Pointer), seqLen, numHeads, headDim, gpu); err != nil {
		return err
	}
	// 5. Recurrent state update
	stateCUDA, ok := recurrentState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: recurrentState not cuda tensor")
	}
	if err := cuda.KDARecurrent(
		qCUDA.Data().(unsafe.Pointer),
		kCUDA.Data().(unsafe.Pointer),
		vCUDA.Data().(unsafe.Pointer),
		gOutCUDA.Data().(unsafe.Pointer),
		betaCUDA.Data().(unsafe.Pointer),
		stateCUDA.Data().(unsafe.Pointer),
		seqLen, numHeads, headDim, gpu,
	); err != nil {
		return err
	}
	// 6. Output gate (g_a -> g_b)
	gGateAAct, err := alloc(tensor.Shape{seqLen, headDim})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, gAProj, gGateAAct); err != nil {
		return err
	}
	gGateBAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, gGateAAct, gBProj, gGateBAct); err != nil {
		return err
	}
	gGateBCUDA, _ := gGateBAct.AsCUDA(gpu)
	// 7. RMSNorm gated (v now contains output from recurrent)
	if oNorm != nil {
		oNormCPU, ok := oNorm.(*cpu.Tensor)
		if !ok {
			return fmt.Errorf("HybridKDA/GPU: oNorm not cpu tensor")
		}
		oNormPtr, err := uploadF32("onorm", oNormCPU)
		if err != nil {
			return err
		}
		if err := cuda.RMSNormGated(vCUDA.Data().(unsafe.Pointer), gGateBCUDA.Data().(unsafe.Pointer), oNormPtr, seqLen*projSize, headDim, eps, gpu); err != nil {
			return err
		}
	}
	// 8. Output projection
	coreAct := NewActivationFrom(vCUDA)
	if err := HybridLinear(ctx, coreAct, oProj, attnOut); err != nil {
		return err
	}
	return nil
}
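
// uploadWeightGPU copies a float32 weight tensor to the given GPU and returns the
// device tensor; the caller owns it and is responsible for freeing it. It assumes
// w is a *cpu.Tensor (the type assertion panics otherwise).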
func uploadWeightGPU(w tensor.Tensor, gpu int) (*cuda.Tensor, error) {
	wCPU := w.(*cpu.Tensor)
	wDev, err := cuda.NewTensor(wCPU.Shape(), tensor.Float32, gpu)
	if err != nil {
		return nil, err
	}
	if err := wDev.CopyFrom(wCPU.DataFloat32()); err != nil {
		wDev.Free()
		return nil, err
	}
	return wDev, nil
}
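
// hybridKDACPU is the pure-CPU path; it mirrors the GPU pipeline step by step
// using the nn package, forcing HybridLinear onto the CPU by passing a nil context.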
func hybridKDACPU(
	hidden *Activation,
	qProj, kProj, vProj tensor.Tensor,
	qConv, kConv, vConv tensor.Tensor,
	fAProj, fBProj, bProj tensor.Tensor,
	aLog tensor.Tensor,
	dtBias tensor.Tensor,
	gAProj, gBProj tensor.Tensor,
	oNorm tensor.Tensor,
	oProj tensor.Tensor,
	convQState, convKState, convVState tensor.Tensor,
	recurrentState tensor.Tensor,
	seqLen, numHeads, headDim, shortConvKernel int,
	eps float32,
	attnOut *Activation,
) error {
	projSize := numHeads * headDim
	hiddenCPU, err := hidden.AsCPU()
	if err != nil {
		return err
	}
	qAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	kAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	vAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	// Note: we intentionally call HybridLinear(nil, ...) to force CPU path.
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), qProj, qAct); err != nil {
		return err
	}
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), kProj, kAct); err != nil {
		return err
	}
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), vProj, vAct); err != nil {
		return err
	}
	qCPU, _ := qAct.AsCPU()
	kCPU, _ := kAct.AsCPU()
	vCPU, _ := vAct.AsCPU()
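	// 1. Causal short conv over Q/K/V, updating the per-layer conv caches in place.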
	qConvStateCPU, ok := convQState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: convQState not cpu tensor")
	}
	kConvStateCPU, ok := convKState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: convKState not cpu tensor")
	}
	vConvStateCPU, ok := convVState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: convVState not cpu tensor")
	}
	qConvW, ok := qConv.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: qConv not cpu tensor")
	}
	kConvW, ok := kConv.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: kConv not cpu tensor")
	}
	vConvW, ok := vConv.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: vConv not cpu tensor")
	}
	if err := nn.CausalShortConv1DInplace(qCPU.DataFloat32(), qConvStateCPU, qConvW, seqLen, projSize, shortConvKernel); err != nil {
		return err
	}
	if err := nn.CausalShortConv1DInplace(kCPU.DataFloat32(), kConvStateCPU, kConvW, seqLen, projSize, shortConvKernel); err != nil {
		return err
	}
	if err := nn.CausalShortConv1DInplace(vCPU.DataFloat32(), vConvStateCPU, vConvW, seqLen, projSize, shortConvKernel); err != nil {
		return err
	}
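	// 2. L2-normalize Q and K per head.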
	nn.L2NormHeads(qCPU.DataFloat32(), kCPU.DataFloat32(), seqLen, numHeads, headDim, 1e-6)
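	// 3. Beta projection + sigmoid.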
	betaAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, numHeads}, nil))
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), bProj, betaAct); err != nil {
		return err
	}
	betaCPU, _ := betaAct.AsCPU()
	nn.SigmoidInplace(betaCPU.DataFloat32())
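	// 4. Gate computation (f_a -> f_b -> per-token KDA gate using aLog and dtBias).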
	gAAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, headDim}, nil))
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), fAProj, gAAct); err != nil {
		return err
	}
	gBAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	if err := HybridLinear(nil, gAAct, fBProj, gBAct); err != nil {
		return err
	}
	gBCPU, _ := gBAct.AsCPU()
	aLogCPU, ok := aLog.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: aLog not cpu tensor")
	}
	dtBiasCPU, ok := dtBias.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: dtBias not cpu tensor")
	}
	aLogFlat, err := nn.FlattenALog(aLogCPU, numHeads)
	if err != nil {
		return err
	}
	gOut := make([]float32, seqLen*projSize)
	for t := 0; t < seqLen; t++ {
		gTok := gBCPU.DataFloat32()[t*projSize : (t+1)*projSize]
		gTok2 := nn.KDAGate(gTok, aLogFlat, headDim, dtBiasCPU.DataFloat32())
		copy(gOut[t*projSize:(t+1)*projSize], gTok2)
	}
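	// 5. Recurrent state update: coreFlat below starts as a copy of V and is
	//    overwritten with the per-token outputs; recurrentState is updated in place.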
	stCPU, ok := recurrentState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: recurrentState not cpu tensor")
	}
	coreFlat := make([]float32, seqLen*projSize)
	copy(coreFlat, vCPU.DataFloat32())
	if err := nn.KDARecurrent(qCPU.DataFloat32(), kCPU.DataFloat32(), coreFlat, gOut, betaCPU.DataFloat32(), stCPU.DataFloat32(), seqLen, numHeads, headDim); err != nil {
		return err
	}
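	// 6. Output gate (g_a -> g_b).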
	gGateAAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, headDim}, nil))
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), gAProj, gGateAAct); err != nil {
		return err
	}
	gGateBAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	if err := HybridLinear(nil, gGateAAct, gBProj, gGateBAct); err != nil {
		return err
	}
	gGateBCPU, _ := gGateBAct.AsCPU()
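	// 7. Gated RMSNorm over the recurrent output (applied only when oNorm is a CPU tensor).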
	if oNorm != nil {
		if w, ok := oNorm.(*cpu.Tensor); ok {
			nn.RMSNormGated(coreFlat, gGateBCPU.DataFloat32(), w.DataFloat32(), headDim, eps)
		}
	}
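	// 8. Output projection.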
	coreAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, coreFlat))
	if err := HybridLinear(nil, coreAct, oProj, attnOut); err != nil {
		return err
	}
	return nil
}
  437. }