hybrid_decoder.go

package arch

import (
	"fmt"
	"math"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/compute"
	"makarna/pkg/kvcache"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)
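// HybridDecoderLayerWeights holds the per-layer weights consumed by the hybrid
// decoder block: the attention and MLP norms, the Q/K/V/output projections,
// optional Q/K norms, and the gate/up/down MLP projections.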
type HybridDecoderLayerWeights struct {
	Idx      int
	AttnNorm tensor.Tensor
	Wq       tensor.Tensor
	Wk       tensor.Tensor
	Wv       tensor.Tensor
	Wo       tensor.Tensor
	QNorm    tensor.Tensor
	KNorm    tensor.Tensor
	MlpNorm  tensor.Tensor
	WGate    tensor.Tensor
	WUp      tensor.Tensor
	WDown    tensor.Tensor
}
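// HybridDecoderConfig describes the layer geometry (hidden size, head counts,
// head dimension, MLP intermediate size) and the RoPE base frequency.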
type HybridDecoderConfig struct {
	HiddenSize   int
	NumHeads     int
	NumKVHeads   int
	Intermediate int
	HeadDim      int
	RopeTheta    float32
}
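// HybridDecoderBlock runs a single decoder layer (attention + MLP, each with a
// residual connection) over hidden in place, using one KV cache for the whole
// sequence.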
func HybridDecoderBlock(ctx *compute.Context, hidden *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, cache kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) error {
	cfg = normalizeHybridDecoderConfig(cfg)
	if err := ensureActivationOnContextDevice(ctx, hidden); err != nil {
		return err
	}
	alloc := makeHybridAllocator(ctx)
	residual, err := cloneActivation(ctx, alloc, hidden, false)
	if err != nil {
		return err
	}
	postAttn, err := runHybridAttentionBlock(ctx, alloc, hidden, residual, layer, positions, cache, cfg, eps)
	if err != nil {
		return err
	}
	out, err := runHybridMLPBlock(ctx, alloc, postAttn, layer, cfg, eps)
	if err != nil {
		return err
	}
	hidden.ReplaceWith(out.Tensor())
	return nil
}
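// HybridDecoderBlockBatch is the batched variant of HybridDecoderBlock: each
// row of hidden is a token from a different sequence and uses its own KV cache
// from caches.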
func HybridDecoderBlockBatch(ctx *compute.Context, hidden *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, caches []kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) error {
	cfg = normalizeHybridDecoderConfig(cfg)
	if err := ensureActivationOnContextDevice(ctx, hidden); err != nil {
		return err
	}
	if len(caches) != hidden.Shape()[0] {
		return fmt.Errorf("caches len %d != hidden batch %d", len(caches), hidden.Shape()[0])
	}
	alloc := makeHybridAllocator(ctx)
	residual, err := cloneActivation(ctx, alloc, hidden, false)
	if err != nil {
		return err
	}
	postAttn, err := runHybridAttentionBlockBatch(ctx, alloc, hidden, residual, layer, positions, caches, cfg, eps)
	if err != nil {
		return err
	}
	out, err := runHybridMLPBlock(ctx, alloc, postAttn, layer, cfg, eps)
	if err != nil {
		return err
	}
	hidden.ReplaceWith(out.Tensor())
	return nil
}
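// normalizeHybridDecoderConfig fills in defaults: NumKVHeads falls back to
// NumHeads and Intermediate falls back to 4*HiddenSize.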
func normalizeHybridDecoderConfig(cfg HybridDecoderConfig) HybridDecoderConfig {
	if cfg.NumKVHeads == 0 {
		cfg.NumKVHeads = cfg.NumHeads
	}
	if cfg.Intermediate == 0 {
		cfg.Intermediate = cfg.HiddenSize * 4
	}
	return cfg
}
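// ensureActivationOnContextDevice moves the activation onto the context's
// placement (CPU or a specific GPU) when a context is provided.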
func ensureActivationOnContextDevice(ctx *compute.Context, a *compute.Activation) error {
	if ctx == nil {
		return nil
	}
	if _, err := a.EnsureOn(ctx.Placement()); err != nil {
		return fmt.Errorf("move hidden to target device: %w", err)
	}
	return nil
}
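// hybridAllocFn allocates an activation of the given shape on the given
// placement; allowScratch permits reuse of the context's GPU scratch arena.
// makeHybridAllocator returns an allocator that prefers scratch memory when it
// matches the target GPU and falls back to a fresh activation otherwise.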
type hybridAllocFn func(shape tensor.Shape, placement tensor.DevicePlacement, allowScratch bool) (*compute.Activation, error)

func makeHybridAllocator(ctx *compute.Context) hybridAllocFn {
	return func(shape tensor.Shape, placement tensor.DevicePlacement, allowScratch bool) (*compute.Activation, error) {
		placement = placement.Normalize()
		if allowScratch && ctx != nil && ctx.Scratch != nil && placement.Type == tensor.CUDA && ctx.Scratch.GPU() == placement.GPU {
			if a, err := ctx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(shape, placement)
	}
}
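// cloneActivation allocates a new activation matching src and copies src into
// it; used to preserve the residual stream before in-place normalization.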
func cloneActivation(ctx *compute.Context, alloc hybridAllocFn, src *compute.Activation, allowScratch bool) (*compute.Activation, error) {
	dst, err := alloc(src.Shape(), src.Placement(), allowScratch)
	if err != nil {
		return nil, fmt.Errorf("alloc residual: %w", err)
	}
	if err := compute.HybridCopy(ctx, dst, src); err != nil {
		return nil, fmt.Errorf("copy residual: %w", err)
	}
	return dst, nil
}
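// runHybridAttentionBlock applies the attention half of the layer: RMSNorm,
// Q/K/V projections, optional Q/K norms, RoPE (standalone or fused into the
// CUDA paged-attention kernel), attention against the KV cache, the output
// projection, and the residual add. It returns residual + attention output.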
func runHybridAttentionBlock(ctx *compute.Context, alloc hybridAllocFn, hidden *compute.Activation, residual *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, cache kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := hidden.Shape()[0]
	profile.Start("Block/AttnNorm")
	if err := compute.HybridRMSNorm(ctx, hidden, layer.AttnNorm, eps); err != nil {
		profile.End("Block/AttnNorm")
		return nil, fmt.Errorf("attn norm: %w", err)
	}
	profile.End("Block/AttnNorm")
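	// Project the normalized hidden states into Q, K and V. Q has
	// NumHeads*HeadDim columns; K and V have NumKVHeads*HeadDim columns, which
	// is smaller under grouped-query attention.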
	qOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	kOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	vOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/QProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wq, qOut); err != nil {
		profile.End("Block/QProj")
		return nil, fmt.Errorf("q_proj: %w", err)
	}
	profile.End("Block/QProj")
	profile.Start("Block/KProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wk, kOut); err != nil {
		profile.End("Block/KProj")
		return nil, fmt.Errorf("k_proj: %w", err)
	}
	profile.End("Block/KProj")
	profile.Start("Block/VProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wv, vOut); err != nil {
		profile.End("Block/VProj")
		return nil, fmt.Errorf("v_proj: %w", err)
	}
	profile.End("Block/VProj")
	if layer.QNorm != nil {
		profile.Start("Block/QNorm")
		if err := compute.HybridRMSNorm(ctx, qOut, layer.QNorm, eps); err != nil {
			profile.End("Block/QNorm")
			return nil, fmt.Errorf("q_norm: %w", err)
		}
		profile.End("Block/QNorm")
	}
	if layer.KNorm != nil {
		profile.Start("Block/KNorm")
		if err := compute.HybridRMSNorm(ctx, kOut, layer.KNorm, eps); err != nil {
			profile.End("Block/KNorm")
			return nil, fmt.Errorf("k_norm: %w", err)
		}
		profile.End("Block/KNorm")
	}
	// Decide whether we can fuse RoPE inside the paged attention kernel.
	// This skips the standalone RoPE kernel launches and avoids an extra read/modify/write of Q/K.
	layerCacheDev := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	var pc *kvcache.PagedKVCache
	useFusedRoPE := false
	if cache != nil {
		layerCacheDev = cache.LayerDevice(layer.Idx).Normalize()
		if p, ok := cache.(*kvcache.PagedKVCache); ok {
			pc = p
		}
	}
	if pc != nil && layerCacheDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		gpu := ctx.Placement().GPU
		// Fused kernel supports headDim<=256 and even headDim.
		if pc.LayerDevice(layer.Idx).GPU == gpu && cfg.HeadDim <= 256 && (cfg.HeadDim&1) == 0 && cfg.RopeTheta != 0 {
			useFusedRoPE = true
		}
	}
	if !useFusedRoPE {
		profile.Start("Block/RoPE_Q")
		if err := compute.HybridRoPE(ctx, qOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_Q")
			return nil, fmt.Errorf("rope q: %w", err)
		}
		profile.End("Block/RoPE_Q")
		profile.Start("Block/RoPE_K")
		if err := compute.HybridRoPE(ctx, kOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_K")
			return nil, fmt.Errorf("rope k: %w", err)
		}
		profile.End("Block/RoPE_K")
	}
	attnOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	scale := float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	profile.Start("Block/Attention")
	didPaged := false
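	// GPU fast path: when the paged KV cache for this layer lives on the
	// current GPU, append K/V to the cache and run the CUDA paged-attention
	// kernel over the cache's block pointer tables; otherwise fall back to the
	// block-view (CPU) or gathered-KV paths below.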
	if cache != nil && layerCacheDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		if pc != nil {
			startPos := cache.SeqLen()
			gpu := ctx.Placement().GPU
			if pc.LayerDevice(layer.Idx).GPU != gpu {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("paged attention requires cache layer %d on gpu %d (got gpu %d)", layer.Idx, gpu, pc.LayerDevice(layer.Idx).GPU)
			}
			if _, _, err := cache.Append(layer.Idx, kOut.Tensor(), vOut.Tensor()); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			kvLen := startPos + seqLen
			bs := pc.BlockSize()
			numBlocks := (kvLen + bs - 1) / bs
			var (
				kDev      unsafe.Pointer
				vDev      unsafe.Pointer
				freeAfter bool
				kvType    tensor.DType
				blockSize int
			)
			// Prefer persistent per-layer device pointer tables (rebuilt only when numBlocks grows).
			kDev, vDev, blockSize, kvType, err = pc.LayerDevicePtrTables(layer.Idx, numBlocks)
			if err != nil {
				// Fallback to per-call scratch allocation/copy if needed.
				kPtrs, vPtrs, blockSize2, kvType2, err2 := pc.LayerBlockPtrTables(layer.Idx, numBlocks)
				if err2 != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("kv ptr tables: %w", err2)
				}
				blockSize = blockSize2
				kvType = kvType2
				if ctx != nil && ctx.Scratch != nil {
					kDev, err = ctx.Scratch.GetUintptrSlice(len(kPtrs))
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("scratch K ptr table: %w", err)
					}
					vDev, err = ctx.Scratch.GetUintptrSlice(len(vPtrs))
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("scratch V ptr table: %w", err)
					}
					if err := cuda.MemcpyH2D(kDev, unsafe.Pointer(&kPtrs[0]), uintptr(len(kPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("memcpy K ptr table: %w", err)
					}
					if err := cuda.MemcpyH2D(vDev, unsafe.Pointer(&vPtrs[0]), uintptr(len(vPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("memcpy V ptr table: %w", err)
					}
				} else {
					freeAfter = true
					kDev, err = cuda.AllocAndCopyPtrTable(kPtrs, gpu)
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("alloc K ptr table: %w", err)
					}
					vDev, err = cuda.AllocAndCopyPtrTable(vPtrs, gpu)
					if err != nil {
						cuda.FreeDevicePtr(kDev)
						profile.End("Block/Attention")
						return nil, fmt.Errorf("alloc V ptr table: %w", err)
					}
				}
			}
			if freeAfter {
				defer cuda.FreeDevicePtr(kDev)
				defer cuda.FreeDevicePtr(vDev)
			}
			gpuQ, err := qOut.AsCUDA(gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("q to cuda: %w", err)
			}
			gpuOut, err := attnOut.AsCUDA(gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attn out to cuda: %w", err)
			}
			// Single-request fast path: use the non-batch paged attention kernel for any seqLen.
			// This avoids per-call device allocations for blockOffsets/kvLens/queryPos.
			if kvType == tensor.Float16 {
				if useFusedRoPE {
					err = cuda.PagedAttentionRoPEF32F16KV(
						gpuQ.Data().(unsafe.Pointer),
						kDev,
						vDev,
						gpuOut.Data().(unsafe.Pointer),
						seqLen, kvLen,
						cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
						blockSize,
						scale, startPos,
						cfg.RopeTheta,
						gpu,
					)
				} else {
					err = cuda.PagedAttentionF32F16KV(
						gpuQ.Data().(unsafe.Pointer),
						kDev,
						vDev,
						gpuOut.Data().(unsafe.Pointer),
						seqLen, kvLen,
						cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
						blockSize,
						scale, startPos,
						gpu,
					)
				}
			} else {
				err = cuda.PagedAttention(
					gpuQ.Data().(unsafe.Pointer),
					kDev,
					vDev,
					gpuOut.Data().(unsafe.Pointer),
					seqLen, kvLen,
					cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize,
					scale, startPos,
					gpu,
				)
			}
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("paged attention: %w", err)
			}
			didPaged = true
		}
	}
	if !didPaged {
		// CPU path: do NOT materialize full K/V history. That is O(kvLen*kvDim) per step
		// and causes severe context-dependent slowdown. Instead, run attention directly
		// over the KV cache block views.
		if cache != nil && layerCacheDev.Type != tensor.CUDA {
			startPos := cache.SeqLen()
			qCPU, err := qOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("q to cpu: %w", err)
			}
			kCPU, err := kOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("k to cpu: %w", err)
			}
			vCPU, err := vOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("v to cpu: %w", err)
			}
			outCPU, err := attnOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attn out to cpu: %w", err)
			}
			views, _, err := cache.Append(layer.Idx, kCPU, vCPU)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
				pviews := pv.ViewsPacked(layer.Idx)
				if len(pviews) != 0 {
					if err := nn.CausalAttentionPackedBlocks(qCPU, pviews, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention packed: %w", err)
					}
				} else {
					if err := nn.CausalAttentionBlocks(qCPU, views, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention blocks: %w", err)
					}
				}
			} else {
				if err := nn.CausalAttentionBlocks(qCPU, views, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("attention blocks: %w", err)
				}
			}
			// If the layer itself is on GPU, restore attention output to GPU so the output
			// projection runs on GPU even when KV cache is on CPU.
			if ctx != nil && ctx.IsGPU() {
				if _, err := attnOut.EnsureOn(ctx.Placement()); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("move attn out to gpu: %w", err)
				}
			}
		} else {
			fullK, fullV, startPos, err := gatherKV(ctx, cache, layer.Idx, kOut, vOut, cfg.NumKVHeads*cfg.HeadDim)
			if err != nil {
				profile.End("Block/Attention")
				return nil, err
			}
			if err := compute.HybridAttention(ctx, qOut, fullK, fullV, attnOut, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, scale, startPos); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attention: %w", err)
			}
		}
	}
	profile.End("Block/Attention")
	attnProj, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/OProj")
	if err := compute.HybridLinear(ctx, attnOut, layer.Wo, attnProj); err != nil {
		profile.End("Block/OProj")
		return nil, fmt.Errorf("o_proj: %w", err)
	}
	profile.End("Block/OProj")
	if err := compute.HybridAdd(ctx, residual, attnProj); err != nil {
		return nil, fmt.Errorf("residual 1: %w", err)
	}
	return residual, nil
}
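// runHybridAttentionBlockBatch is the batched counterpart of
// runHybridAttentionBlock for single-token decode across multiple sequences:
// token i appends to and attends over caches[i].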
func runHybridAttentionBlockBatch(ctx *compute.Context, alloc hybridAllocFn, hidden *compute.Activation, residual *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, caches []kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := hidden.Shape()[0]
	if len(caches) != seqLen {
		return nil, fmt.Errorf("caches len %d != seqLen %d", len(caches), seqLen)
	}
	if len(positions) != seqLen {
		return nil, fmt.Errorf("positions len %d != seqLen %d", len(positions), seqLen)
	}
	profile.Start("Block/AttnNorm")
	if err := compute.HybridRMSNorm(ctx, hidden, layer.AttnNorm, eps); err != nil {
		profile.End("Block/AttnNorm")
		return nil, fmt.Errorf("attn norm: %w", err)
	}
	profile.End("Block/AttnNorm")
	qOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	kOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	vOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/QProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wq, qOut); err != nil {
		profile.End("Block/QProj")
		return nil, fmt.Errorf("q_proj: %w", err)
	}
	profile.End("Block/QProj")
	profile.Start("Block/KProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wk, kOut); err != nil {
		profile.End("Block/KProj")
		return nil, fmt.Errorf("k_proj: %w", err)
	}
	profile.End("Block/KProj")
	profile.Start("Block/VProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wv, vOut); err != nil {
		profile.End("Block/VProj")
		return nil, fmt.Errorf("v_proj: %w", err)
	}
	profile.End("Block/VProj")
	if layer.QNorm != nil {
		profile.Start("Block/QNorm")
		if err := compute.HybridRMSNorm(ctx, qOut, layer.QNorm, eps); err != nil {
			profile.End("Block/QNorm")
			return nil, fmt.Errorf("q_norm: %w", err)
		}
		profile.End("Block/QNorm")
	}
	if layer.KNorm != nil {
		profile.Start("Block/KNorm")
		if err := compute.HybridRMSNorm(ctx, kOut, layer.KNorm, eps); err != nil {
			profile.End("Block/KNorm")
			return nil, fmt.Errorf("k_norm: %w", err)
		}
		profile.End("Block/KNorm")
	}
	// Only skip standalone RoPE when we will actually use fused RoPE inside CUDA paged attention.
	canPagedCUDA := false
	if ctx != nil && ctx.IsGPU() && cuda.Available() {
		gpu := ctx.Placement().GPU
		canPagedCUDA = true
		for i := 0; i < seqLen; i++ {
			if caches[i] == nil {
				canPagedCUDA = false
				break
			}
			pc, ok := caches[i].(*kvcache.PagedKVCache)
			if !ok || pc == nil || !pc.IsOnGPU() || pc.LayerDevice(layer.Idx).GPU != gpu {
				canPagedCUDA = false
				break
			}
		}
	}
	useFusedRoPE := canPagedCUDA && cfg.HeadDim <= 256 && (cfg.HeadDim&1) == 0 && cfg.RopeTheta != 0
	if !useFusedRoPE {
		profile.Start("Block/RoPE_Q")
		if err := compute.HybridRoPE(ctx, qOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_Q")
			return nil, fmt.Errorf("rope q: %w", err)
		}
		profile.End("Block/RoPE_Q")
		profile.Start("Block/RoPE_K")
		if err := compute.HybridRoPE(ctx, kOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_K")
			return nil, fmt.Errorf("rope k: %w", err)
		}
		profile.End("Block/RoPE_K")
	}
	attnOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	scale := float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	profile.Start("Block/Attention")
	didPaged := false
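	// Each row of the batch is one decode token with its own cache: token i
	// appends a single K/V row to caches[i] and attends over that cache's
	// history (kvLen = startPos + 1).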
	if canPagedCUDA {
		gpu := ctx.Placement().GPU
		gpuQ, err := qOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("q to cuda: %w", err)
		}
		gpuK, err := kOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("k to cuda: %w", err)
		}
		gpuV, err := vOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("v to cuda: %w", err)
		}
		gpuOut, err := attnOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("attn out to cuda: %w", err)
		}
		flatKPtrs := make([]uintptr, 0)
		flatVPtrs := make([]uintptr, 0)
		blockOffsets := make([]int32, seqLen)
		kvLens := make([]int32, seqLen)
		queryPos := make([]int32, seqLen)
		maxKvLen := 0
		var kvType tensor.DType
		blockSize := 0
		kvDim := cfg.NumKVHeads * cfg.HeadDim
		for i := 0; i < seqLen; i++ {
			pc := caches[i].(*kvcache.PagedKVCache)
			startPos := pc.SeqLen()
			offBytes := uintptr(i * kvDim * 4)
			kView, err := gpuK.ViewAt(tensor.Shape{1, kvDim}, offBytes)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("k view: %w", err)
			}
			vView, err := gpuV.ViewAt(tensor.Shape{1, kvDim}, offBytes)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("v view: %w", err)
			}
			if _, _, err := pc.Append(layer.Idx, kView, vView); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			kvLen := startPos + 1
			if kvLen > maxKvLen {
				maxKvLen = kvLen
			}
			bs := pc.BlockSize()
			if blockSize == 0 {
				blockSize = bs
			} else if blockSize != bs {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("mixed block sizes in batch: %d vs %d", blockSize, bs)
			}
			nBlocks := (kvLen + bs - 1) / bs
			kPtrs, vPtrs, _, curType, err := pc.LayerBlockPtrTables(layer.Idx, nBlocks)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("kv ptr tables: %w", err)
			}
			if kvType == 0 {
				kvType = curType
			} else if kvType != curType {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("mixed KV dtypes in batch: %v vs %v", kvType, curType)
			}
			blockOffsets[i] = int32(len(flatKPtrs))
			flatKPtrs = append(flatKPtrs, kPtrs...)
			flatVPtrs = append(flatVPtrs, vPtrs...)
			kvLens[i] = int32(kvLen)
			queryPos[i] = int32(startPos)
		}
		if blockSize <= 0 {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("batched attention: invalid blockSize %d", blockSize)
		}
		if len(flatKPtrs) == 0 || len(flatVPtrs) == 0 {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("batched attention: empty KV ptr tables")
		}
		var (
			kFlatDev  unsafe.Pointer
			vFlatDev  unsafe.Pointer
			offDev    unsafe.Pointer
			kvDev     unsafe.Pointer
			qposDev   unsafe.Pointer
			freeAfter bool
		)
		if ctx != nil && ctx.Scratch != nil {
			kFlatDev, err = ctx.Scratch.GetUintptrSlice(len(flatKPtrs))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch K ptr table: %w", err)
			}
			vFlatDev, err = ctx.Scratch.GetUintptrSlice(len(flatVPtrs))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch V ptr table: %w", err)
			}
			offDev, err = ctx.Scratch.GetInt32Slice(len(blockOffsets))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch offsets: %w", err)
			}
			kvDev, err = ctx.Scratch.GetInt32Slice(len(kvLens))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch kv lens: %w", err)
			}
			qposDev, err = ctx.Scratch.GetInt32Slice(len(queryPos))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch query pos: %w", err)
			}
			if err := cuda.MemcpyH2D(kFlatDev, unsafe.Pointer(&flatKPtrs[0]), uintptr(len(flatKPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy K ptr table: %w", err)
			}
			if err := cuda.MemcpyH2D(vFlatDev, unsafe.Pointer(&flatVPtrs[0]), uintptr(len(flatVPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy V ptr table: %w", err)
			}
			if err := cuda.MemcpyH2D(offDev, unsafe.Pointer(&blockOffsets[0]), uintptr(len(blockOffsets))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy offsets: %w", err)
			}
			if err := cuda.MemcpyH2D(kvDev, unsafe.Pointer(&kvLens[0]), uintptr(len(kvLens))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy kv lens: %w", err)
			}
			if err := cuda.MemcpyH2D(qposDev, unsafe.Pointer(&queryPos[0]), uintptr(len(queryPos))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy query pos: %w", err)
			}
		} else {
			freeAfter = true
			kFlatDev, err = cuda.AllocAndCopyPtrTable(flatKPtrs, gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc K ptr table: %w", err)
			}
			vFlatDev, err = cuda.AllocAndCopyPtrTable(flatVPtrs, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc V ptr table: %w", err)
			}
			offDev, err = cuda.AllocAndCopyInt32(blockOffsets, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc offsets: %w", err)
			}
			kvDev, err = cuda.AllocAndCopyInt32(kvLens, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				cuda.FreeDevicePtr(offDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc kv lens: %w", err)
			}
			qposDev, err = cuda.AllocAndCopyInt32(queryPos, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				cuda.FreeDevicePtr(offDev)
				cuda.FreeDevicePtr(kvDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc query pos: %w", err)
			}
		}
		if freeAfter {
			defer cuda.FreeDevicePtr(kFlatDev)
			defer cuda.FreeDevicePtr(vFlatDev)
			defer cuda.FreeDevicePtr(offDev)
			defer cuda.FreeDevicePtr(kvDev)
			defer cuda.FreeDevicePtr(qposDev)
		}
		if kvType == tensor.Float16 {
			if useFusedRoPE {
				err = cuda.PagedAttentionBatchRoPEF32F16KV(
					gpuQ.Data().(unsafe.Pointer),
					kFlatDev,
					vFlatDev,
					offDev,
					kvDev,
					qposDev,
					gpuOut.Data().(unsafe.Pointer),
					seqLen,
					cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize,
					scale,
					maxKvLen,
					cfg.RopeTheta,
					gpu,
				)
			} else {
				err = cuda.PagedAttentionBatchF32F16KV(
					gpuQ.Data().(unsafe.Pointer),
					kFlatDev,
					vFlatDev,
					offDev,
					kvDev,
					qposDev,
					gpuOut.Data().(unsafe.Pointer),
					seqLen,
					cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize,
					scale,
					maxKvLen,
					gpu,
				)
			}
		} else {
			err = cuda.PagedAttentionBatch(
				gpuQ.Data().(unsafe.Pointer),
				kFlatDev,
				vFlatDev,
				offDev,
				kvDev,
				qposDev,
				gpuOut.Data().(unsafe.Pointer),
				seqLen,
				cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
				blockSize,
				scale,
				maxKvLen,
				gpu,
			)
		}
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("paged attention batch: %w", err)
		}
		didPaged = true
	}
	if !didPaged {
		// CPU KV cache path (supports running the layer on GPU while keeping KV on CPU).
		qCPU, err := qOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("q to cpu: %w", err)
		}
		kCPU, err := kOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("k to cpu: %w", err)
		}
		vCPU, err := vOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("v to cpu: %w", err)
		}
		outCPU, err := attnOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("attn out to cpu: %w", err)
		}
		qStride := cfg.NumHeads * cfg.HeadDim
		kvDim := cfg.NumKVHeads * cfg.HeadDim
		kAll := kCPU.DataFloat32()
		vAll := vCPU.DataFloat32()
		outAll := outCPU.DataFloat32()
		queryPos := make([]int, seqLen)
		packedViews := make([][]kvcache.PackedView, seqLen)
		viewsByToken := make([][]kvcache.View, seqLen)
		allPacked := true
		for i := 0; i < seqLen; i++ {
			if caches[i] == nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("batched attention requires non-nil KV cache for token %d", i)
			}
			layerDev := caches[i].LayerDevice(layer.Idx).Normalize()
			if layerDev.Type == tensor.CUDA {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("CPU batched attention requires CPU KV cache for token %d (got %v)", i, layerDev)
			}
			queryPos[i] = caches[i].SeqLen()
			kRow := cpu.NewTensor(tensor.Shape{1, kvDim}, kAll[i*kvDim:(i+1)*kvDim])
			vRow := cpu.NewTensor(tensor.Shape{1, kvDim}, vAll[i*kvDim:(i+1)*kvDim])
			views, _, err := caches[i].Append(layer.Idx, kRow, vRow)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			viewsByToken[i] = views
			if pv, ok := caches[i].(kvcache.PackedViewsProvider); ok {
				p := pv.ViewsPacked(layer.Idx)
				if len(p) > 0 {
					packedViews[i] = p
				} else {
					allPacked = false
				}
			} else {
				allPacked = false
			}
		}
		if allPacked {
			if err := nn.CausalAttentionPackedBlocksBatch(qCPU, packedViews, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attention packed batch: %w", err)
			}
		} else {
			for i := 0; i < seqLen; i++ {
				qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qCPU.DataFloat32()[i*qStride:(i+1)*qStride])
				outRow := cpu.NewTensor(tensor.Shape{1, qStride}, outAll[i*qStride:(i+1)*qStride])
				if len(packedViews[i]) != 0 {
					if err := nn.CausalAttentionPackedBlocks(qRow, packedViews[i], outRow, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos[i]); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention packed: %w", err)
					}
					continue
				}
				if err := nn.CausalAttentionBlocks(qRow, viewsByToken[i], outRow, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos[i]); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("attention blocks: %w", err)
				}
			}
		}
		// Restore attention output to GPU so the output projection stays on GPU.
		if ctx != nil && ctx.IsGPU() {
			if _, err := attnOut.EnsureOn(ctx.Placement()); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("move attn out to gpu: %w", err)
			}
		}
		didPaged = true
	}
	profile.End("Block/Attention")
	attnProj, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/OProj")
	if err := compute.HybridLinear(ctx, attnOut, layer.Wo, attnProj); err != nil {
		profile.End("Block/OProj")
		return nil, fmt.Errorf("o_proj: %w", err)
	}
	profile.End("Block/OProj")
	if err := compute.HybridAdd(ctx, residual, attnProj); err != nil {
		return nil, fmt.Errorf("residual 1: %w", err)
	}
	return residual, nil
}
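// runHybridMLPBlock applies the MLP half of the layer: RMSNorm, gate/up
// projections, SwiGLU, down projection, and the residual add.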
func runHybridMLPBlock(ctx *compute.Context, alloc hybridAllocFn, postAttn *compute.Activation, layer *HybridDecoderLayerWeights, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := postAttn.Shape()[0]
	residual2, err := cloneActivation(ctx, alloc, postAttn, false)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/MlpNorm")
	if err := compute.HybridRMSNorm(ctx, postAttn, layer.MlpNorm, eps); err != nil {
		profile.End("Block/MlpNorm")
		return nil, fmt.Errorf("mlp norm: %w", err)
	}
	profile.End("Block/MlpNorm")
	gate, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	up, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/GateProj")
	if err := compute.HybridLinear(ctx, postAttn, layer.WGate, gate); err != nil {
		profile.End("Block/GateProj")
		return nil, fmt.Errorf("gate proj: %w", err)
	}
	profile.End("Block/GateProj")
	profile.Start("Block/UpProj")
	if err := compute.HybridLinear(ctx, postAttn, layer.WUp, up); err != nil {
		profile.End("Block/UpProj")
		return nil, fmt.Errorf("up proj: %w", err)
	}
	profile.End("Block/UpProj")
	profile.Start("Block/SwiGLU")
	act, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		profile.End("Block/SwiGLU")
		return nil, err
	}
	if err := compute.HybridSwiGLU(ctx, gate, up, act); err != nil {
		profile.End("Block/SwiGLU")
		return nil, err
	}
	profile.End("Block/SwiGLU")
	mlpOut, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/DownProj")
	if err := compute.HybridLinear(ctx, act, layer.WDown, mlpOut); err != nil {
		profile.End("Block/DownProj")
		return nil, fmt.Errorf("down proj: %w", err)
	}
	profile.End("Block/DownProj")
	if err := compute.HybridAdd(ctx, residual2, mlpOut); err != nil {
		return nil, fmt.Errorf("residual 2: %w", err)
	}
	return residual2, nil
}
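// gatherKV appends the current step's K/V to the cache and returns full K/V
// tensors covering the whole history plus the attention start position. On a
// GPU cache it prefers a contiguous cache view and otherwise concatenates the
// block views on device; on a CPU cache it materializes the history into newly
// allocated tensors and moves them to the GPU if the layer runs there.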
func gatherKV(ctx *compute.Context, cache kvcache.KVCacheInterface, layerIdx int, kOut, vOut *compute.Activation, kvDim int) (*compute.Activation, *compute.Activation, int, error) {
	seqLen := kOut.Shape()[0]
	startPos := 0
	if cache == nil {
		return kOut, vOut, startPos, nil
	}
	startPos = cache.SeqLen()
	layerDev := cache.LayerDevice(layerIdx).Normalize()
	if layerDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		profile.Instant("KVCache/GPU_path", profile.EventOp, "")
		if _, _, err := cache.Append(layerIdx, kOut.Tensor(), vOut.Tensor()); err != nil {
			return nil, nil, 0, fmt.Errorf("cache append: %w", err)
		}
		kvLen := startPos + seqLen
		if kView, vView, ok, err := cache.ContiguousKV(layerIdx, kvLen, kvDim); err != nil {
			return nil, nil, 0, fmt.Errorf("contiguous kv: %w", err)
		} else if ok {
			if kView == nil || vView == nil {
				return kOut, vOut, startPos, nil
			}
			return compute.NewActivationFrom(kView), compute.NewActivationFrom(vView), startPos, nil
		}
		kAct, vAct, err := concatKVOnDevice(ctx, cache.Views(layerIdx), startPos+seqLen, kvDim)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("concat kv on device: %w", err)
		}
		// Append the current step's K/V to the end of the full buffer.
		// PagedKVCache views only include committed tokens, so we must manually append the new ones.
		kDst := kAct.Tensor().(*cuda.Tensor)
		vDst := vAct.Tensor().(*cuda.Tensor)
		kSrc := kOut.Tensor().(*cuda.Tensor)
		vSrc := vOut.Tensor().(*cuda.Tensor)
		dstOffset := startPos * kvDim
		copyLen := seqLen * kvDim
		if err := kDst.CopyPartialFromDevice(dstOffset, kSrc, 0, copyLen); err != nil {
			return nil, nil, 0, fmt.Errorf("copy current K: %w", err)
		}
		if err := vDst.CopyPartialFromDevice(dstOffset, vSrc, 0, copyLen); err != nil {
			return nil, nil, 0, fmt.Errorf("copy current V: %w", err)
		}
		return kAct, vAct, startPos, nil
	}
	profile.Instant("KVCache/CPU_path", profile.EventOp, fmt.Sprintf("layerDev=%v ctxGPU=%v", layerDev, ctx != nil && ctx.IsGPU()))
	kCPU, err := kOut.AsCPU()
	if err != nil {
		return nil, nil, 0, err
	}
	vCPU, err := vOut.AsCPU()
	if err != nil {
		return nil, nil, 0, err
	}
	views, _, err := cache.Append(layerIdx, kCPU, vCPU)
	if err != nil {
		return nil, nil, 0, fmt.Errorf("cache append: %w", err)
	}
	kvLen := startPos + seqLen
	fullKData := make([]float32, kvLen*kvDim)
	fullVData := make([]float32, kvLen*kvDim)
	for _, view := range views {
		kData, err := getViewData(view.K)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("get K data: %w", err)
		}
		vData, err := getViewData(view.V)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("get V data: %w", err)
		}
		for i := 0; i < view.Length; i++ {
			globalPos := view.Start + i
			copy(fullKData[globalPos*kvDim:(globalPos+1)*kvDim], kData[i*kvDim:(i+1)*kvDim])
			copy(fullVData[globalPos*kvDim:(globalPos+1)*kvDim], vData[i*kvDim:(i+1)*kvDim])
		}
	}
	fullK := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{kvLen, kvDim}, fullKData))
	fullV := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{kvLen, kvDim}, fullVData))
	if ctx != nil && ctx.IsGPU() {
		profile.Start("KVCache/FullK_H2D")
		if err := compute.EnsureOnDevice(fullK, ctx.Placement()); err != nil {
			profile.End("KVCache/FullK_H2D")
			return nil, nil, 0, err
		}
		profile.End("KVCache/FullK_H2D")
		profile.Start("KVCache/FullV_H2D")
		if err := compute.EnsureOnDevice(fullV, ctx.Placement()); err != nil {
			profile.End("KVCache/FullV_H2D")
			return nil, nil, 0, err
		}
		profile.End("KVCache/FullV_H2D")
	}
	return fullK, fullV, startPos, nil
}
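// getViewData returns the float32 contents of a cache view tensor, copying
// device tensors to host when they implement CopyToHost.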
func getViewData(t tensor.Tensor) ([]float32, error) {
	switch tt := t.(type) {
	case *cpu.Tensor:
		return tt.DataFloat32(), nil
	default:
		if copier, ok := t.(interface{ CopyToHost([]float32) error }); ok {
			data := make([]float32, t.Shape().NumElements())
			if err := copier.CopyToHost(data); err != nil {
				return nil, err
			}
			return data, nil
		}
		return nil, fmt.Errorf("unsupported tensor type: %T", t)
	}
}
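// concatKVOnDevice copies the cache's K/V block views into contiguous device
// buffers of shape {kvLen, kvDim}, preferring the context's scratch arena for
// the allocations.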
func concatKVOnDevice(ctx *compute.Context, views []kvcache.View, kvLen, kvDim int) (*compute.Activation, *compute.Activation, error) {
	gpu := ctx.Placement().GPU
	fullKAct, err := func() (*compute.Activation, error) {
		if ctx != nil && ctx.Scratch != nil {
			if a, err := ctx.Scratch.GetTensor(tensor.Shape{kvLen, kvDim}, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(tensor.Shape{kvLen, kvDim}, ctx.Placement())
	}()
	if err != nil {
		return nil, nil, err
	}
	fullVAct, err := func() (*compute.Activation, error) {
		if ctx != nil && ctx.Scratch != nil {
			if a, err := ctx.Scratch.GetTensor(tensor.Shape{kvLen, kvDim}, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(tensor.Shape{kvLen, kvDim}, ctx.Placement())
	}()
	if err != nil {
		return nil, nil, err
	}
	fullKGPU, err := fullKAct.AsCUDA(gpu)
	if err != nil {
		return nil, nil, err
	}
	fullVGPU, err := fullVAct.AsCUDA(gpu)
	if err != nil {
		return nil, nil, err
	}
	for _, view := range views {
		kSrc, ok := view.K.(*cuda.Tensor)
		if !ok {
			return nil, nil, fmt.Errorf("expected CUDA tensor for K view, got %T", view.K)
		}
		vSrc, ok := view.V.(*cuda.Tensor)
		if !ok {
			return nil, nil, fmt.Errorf("expected CUDA tensor for V view, got %T", view.V)
		}
		dstStart := view.Start * kvDim
		length := view.Length * kvDim
		if err := fullKGPU.CopyPartialFromDevice(dstStart, kSrc, 0, length); err != nil {
			return nil, nil, fmt.Errorf("copy K view: %w", err)
		}
		if err := fullVGPU.CopyPartialFromDevice(dstStart, vSrc, 0, length); err != nil {
			return nil, nil, fmt.Errorf("copy V view: %w", err)
		}
	}
	return fullKAct, fullVAct, nil
}