// Package kvcache implements a paged KV cache with global prefix caching.
package kvcache

import (
	"fmt"
	"math"
	"sync"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/tensor"
)

// float32ToFloat16Bits converts an IEEE-754 float32 value to its float16
// (binary16) bit pattern using round-to-nearest-even.
func float32ToFloat16Bits(f float32) uint16 {
	bits := math.Float32bits(f)
	sign := uint16((bits >> 16) & 0x8000)
	exp := int((bits >> 23) & 0xFF)
	mant := bits & 0x007FFFFF
	// NaN / Inf
	if exp == 255 {
		if mant == 0 {
			return sign | 0x7C00
		}
		m := uint16(mant >> 13)
		if m == 0 {
			m = 1
		}
		return sign | 0x7C00 | m
	}
	// Bias adjust: float32 bias=127, float16 bias=15.
	exp16 := exp - 127 + 15
	// Overflow -> Inf
	if exp16 >= 31 {
		return sign | 0x7C00
	}
	// Subnormal / underflow
	if exp16 <= 0 {
		if exp16 < -10 {
			return sign
		}
		// Add implicit leading 1.
		mant |= 0x00800000
		// Shift to 10-bit mantissa (plus 13-bit alignment), with round-to-nearest-even.
		shift := uint32(1-exp16) + 13
		m16 := mant >> shift
		rem := mant & ((uint32(1) << shift) - 1)
		half := uint32(1) << (shift - 1)
		if rem > half || (rem == half && (m16&1) == 1) {
			m16++
		}
		return sign | uint16(m16)
	}
	// Normalized: round mantissa to 10 bits (round-to-nearest-even).
	m16 := mant >> 13
	rem := mant & 0x1FFF
	if rem > 0x1000 || (rem == 0x1000 && (m16&1) == 1) {
		m16++
		if m16 == 0x400 {
			m16 = 0
			exp16++
			if exp16 >= 31 {
				return sign | 0x7C00
			}
		}
	}
	return sign | uint16(exp16<<10) | uint16(m16)
}

// float32ToBFloat16Bits converts an IEEE-754 float32 value to its bfloat16
// bit pattern (the upper 16 bits) using round-to-nearest-even.
func float32ToBFloat16Bits(f float32) uint16 {
	bits := math.Float32bits(f)
	upper := uint16(bits >> 16)
	lower := uint16(bits & 0xFFFF)
	if lower > 0x8000 || (lower == 0x8000 && (upper&1) == 1) {
		upper++
	}
	return upper
}

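// float16ConversionSketch is an illustrative sketch (not used by the cache
// itself) of the bit patterns produced by the converters above; the values in
// the comments are standard IEEE-754 encodings.
func float16ConversionSketch() {
	_ = float32ToFloat16Bits(1.0)   // 0x3C00: sign 0, exponent 15, mantissa 0
	_ = float32ToFloat16Bits(-2.0)  // 0xC000
	_ = float32ToFloat16Bits(65504) // 0x7BFF: largest finite float16
	_ = float32ToBFloat16Bits(1.0)  // 0x3F80: upper 16 bits of the float32 encoding
}
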
// PagedKVCache is a paged KV cache that uses a global BlockPool.
// Unlike the old Cache, which allocated per-request, this shares blocks across
// requests and enables prefix caching via block hashing.
type PagedKVCache struct {
	pool *BlockPool
	cfg  PagedCacheConfig
	mu   sync.RWMutex

	// Per-request state
	requestID string
	tokenIDs  []int

	// Block allocation per layer: blockIDs[layer] = list of block IDs
	blockIDs [][]int
	// How many tokens have computed KV
	numComputed int
	// How many tokens have written KV (may be ahead of numComputed within a step).
	numWritten int
	// Block hashes for prefix matching
	blockHashes []BlockHash
	// Whether this cache owns its blocks (vs borrowed from the prefix cache)
	ownedBlocks [][]bool

	// Cached device pointer tables for paged attention (per layer).
	// These are (re)built only when the requested block count or device changes.
	ptrTables []devicePtrTable
}

// devicePtrTable caches a device-resident table of per-block base pointers
// for one layer's K and V blocks.
type devicePtrTable struct {
	kDev   unsafe.Pointer
	vDev   unsafe.Pointer
	len    int
	gpu    int
	kvType tensor.DType
}

// clearPtrTablesLocked frees any cached device pointer tables.
// The caller must hold c.mu.
func (c *PagedKVCache) clearPtrTablesLocked() {
	if c.ptrTables == nil {
		return
	}
	for i := range c.ptrTables {
		if c.ptrTables[i].kDev != nil {
			cuda.FreeDevicePtr(c.ptrTables[i].kDev)
		}
		if c.ptrTables[i].vDev != nil {
			cuda.FreeDevicePtr(c.ptrTables[i].vDev)
		}
		c.ptrTables[i] = devicePtrTable{}
	}
}

// ensureCapacityLocked grows every layer's block list so that at least
// requiredBlocks blocks are allocated, rolling back on partial failure.
// The caller must hold c.mu.
func (c *PagedKVCache) ensureCapacityLocked(requiredBlocks int) error {
	if requiredBlocks <= 0 {
		return nil
	}
	if len(c.blockIDs) == 0 {
		return fmt.Errorf("cache not initialized")
	}
	// All layers are expected to have the same number of blocks.
	cur := len(c.blockIDs[0])
	if requiredBlocks <= cur {
		return nil
	}
	need := requiredBlocks - cur
	allocatedByLayer := make([][]int, c.cfg.NumLayers)
	for layer := 0; layer < c.cfg.NumLayers; layer++ {
		newBlocks, err := c.pool.AllocateBlocks(layer, need)
		if err != nil {
			// Roll back blocks allocated for earlier layers.
			for l := 0; l < layer; l++ {
				c.pool.FreeBlocks(l, allocatedByLayer[l])
			}
			return err
		}
		allocatedByLayer[layer] = newBlocks
		c.blockIDs[layer] = append(c.blockIDs[layer], newBlocks...)
		for range newBlocks {
			c.ownedBlocks[layer] = append(c.ownedBlocks[layer], true)
		}
	}
	return nil
}

// AppendToken updates the token history for this request.
// This is used to extend block hashing/caching beyond the initial prompt.
func (c *PagedKVCache) AppendToken(tokenID int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.tokenIDs = append(c.tokenIDs, tokenID)
	// Only add hashes for fully completed blocks.
	bs := c.pool.cfg.BlockSize
	if bs <= 0 {
		return
	}
	if len(c.tokenIDs)%bs != 0 {
		return
	}
	blockIdx := len(c.tokenIDs)/bs - 1
	if blockIdx < 0 {
		return
	}
	// Ensure blockIDs can hold this block.
	_ = c.ensureCapacityLocked(blockIdx + 1)
	start := blockIdx * bs
	end := start + bs
	if end > len(c.tokenIDs) {
		end = len(c.tokenIDs)
	}
	var parent *BlockHash
	if blockIdx-1 >= 0 && blockIdx-1 < len(c.blockHashes) {
		parent = &c.blockHashes[blockIdx-1]
	}
	h := ComputeBlockHash(c.tokenIDs[start:end], parent)
	// Extend or overwrite.
	if blockIdx < len(c.blockHashes) {
		c.blockHashes[blockIdx] = h
	} else {
		c.blockHashes = append(c.blockHashes, h)
	}
}

// PagedCacheConfig configures a paged KV cache.
type PagedCacheConfig struct {
	NumLayers  int
	NumKVHeads int
	HeadDim    int
	BlockSize  int
	MaxSeqLen  int
	Device     tensor.DeviceType
	GPU        int
}

// NewPagedKVCache creates a new paged cache backed by the given block pool.
func NewPagedKVCache(pool *BlockPool, cfg PagedCacheConfig, requestID string) *PagedKVCache {
	return &PagedKVCache{
		pool:        pool,
		cfg:         cfg,
		requestID:   requestID,
		blockIDs:    make([][]int, cfg.NumLayers),
		ownedBlocks: make([][]bool, cfg.NumLayers),
	}
}

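// pagedCacheUsageSketch is an illustrative, non-authoritative sketch of the
// per-request lifecycle (allocate, append, commit, read views, free). It
// assumes the caller already has a BlockPool and, for every layer, K/V
// tensors covering the tokens that still need computing; how those are
// produced is outside this file.
func pagedCacheUsageSketch(pool *BlockPool, cfg PagedCacheConfig, tokens []int, k, v []tensor.Tensor) error {
	cache := NewPagedKVCache(pool, cfg, "req-0")
	defer cache.Free()

	// Allocate blocks for the prompt. cached is the number of leading tokens
	// whose KV is already present via the global prefix cache; only
	// tokens[cached:] need a forward pass.
	cached, err := cache.AllocateForTokens(tokens)
	if err != nil {
		return err
	}
	_ = cached

	// Write the new K/V rows for every layer, then commit the step so the
	// newly completed blocks become visible to the prefix cache.
	for layer := 0; layer < cfg.NumLayers; layer++ {
		if _, _, err := cache.Append(layer, k[layer], v[layer]); err != nil {
			return err
		}
	}
	cache.Commit(k[0].Shape()[0])

	// Views exposes the committed KV as per-block views for attention.
	_ = cache.Views(0)
	return nil
}
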
// AllocateForTokens allocates blocks for the given token sequence.
// It returns the number of tokens that are already cached (prefix hit).
func (c *PagedKVCache) AllocateForTokens(tokens []int) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.cfg.MaxSeqLen > 0 && len(tokens) > c.cfg.MaxSeqLen {
		return 0, fmt.Errorf("prompt length %d exceeds MaxSeqLen %d", len(tokens), c.cfg.MaxSeqLen)
	}
	c.tokenIDs = tokens
	c.blockHashes = ComputeBlockHashes(tokens, c.pool.cfg.BlockSize)
	c.clearPtrTablesLocked()
	numBlocks := len(c.blockHashes)
	if numBlocks == 0 {
		return 0, nil
	}
	// Find the cached prefix in layer 0 (all layers have the same structure).
	cachedBlockIDs, cachedTokens := c.pool.FindCachedBlocks(0, c.blockHashes)
	numCachedBlocks := len(cachedBlockIDs)
	// Allocate blocks for all layers.
	for layer := 0; layer < c.cfg.NumLayers; layer++ {
		c.blockIDs[layer] = make([]int, 0, numBlocks)
		c.ownedBlocks[layer] = make([]bool, 0, numBlocks)
		// First, reuse cached blocks for this layer.
		if numCachedBlocks > 0 {
			layerCached, _ := c.pool.FindCachedBlocks(layer, c.blockHashes[:numCachedBlocks])
			c.pool.TouchBlocks(layer, layerCached)
			c.blockIDs[layer] = append(c.blockIDs[layer], layerCached...)
			for range layerCached {
				c.ownedBlocks[layer] = append(c.ownedBlocks[layer], false) // borrowed
			}
		}
		// Allocate new blocks for the uncached portion.
		numNewBlocks := numBlocks - numCachedBlocks
		if numNewBlocks > 0 {
			newBlocks, err := c.pool.AllocateBlocks(layer, numNewBlocks)
			if err != nil {
				// Roll back on failure.
				for l := 0; l < layer; l++ {
					c.pool.FreeBlocks(l, c.blockIDs[l])
				}
				return 0, fmt.Errorf("layer %d allocation failed: %w", layer, err)
			}
			c.blockIDs[layer] = append(c.blockIDs[layer], newBlocks...)
			for range newBlocks {
				c.ownedBlocks[layer] = append(c.ownedBlocks[layer], true) // owned
			}
		}
	}
	// Don't report the very last token as cached so we always compute logits
	// for the final position.
	c.numComputed = cachedTokens
	c.numWritten = c.numComputed
	if c.numComputed >= len(tokens) {
		c.numComputed = len(tokens) - 1
		if c.numComputed < 0 {
			c.numComputed = 0
		}
		c.numWritten = c.numComputed
	}
	return c.numComputed, nil
}

// LayerDevicePtrTables returns device pointers to contiguous pointer tables for this layer.
// The returned pointers live until the cache is freed (or reallocated) and are rebuilt only
// when the requested number of blocks or the device changes.
func (c *PagedKVCache) LayerDevicePtrTables(layer int, numBlocks int) (unsafe.Pointer, unsafe.Pointer, int, tensor.DType, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.pool == nil {
		return nil, nil, 0, 0, fmt.Errorf("cache not initialized")
	}
	if layer < 0 || layer >= len(c.blockIDs) {
		return nil, nil, 0, 0, fmt.Errorf("invalid layer %d", layer)
	}
	if numBlocks <= 0 {
		return nil, nil, 0, 0, fmt.Errorf("numBlocks must be > 0")
	}
	layerDev := c.pool.LayerDevice(layer).Normalize()
	if layerDev.Type != tensor.CUDA {
		return nil, nil, 0, 0, fmt.Errorf("layer %d is not on CUDA (got %v)", layer, layerDev)
	}
	if numBlocks > len(c.blockIDs[layer]) {
		numBlocks = len(c.blockIDs[layer])
	}
	if numBlocks <= 0 {
		return nil, nil, 0, 0, fmt.Errorf("no blocks")
	}
	if c.ptrTables == nil || len(c.ptrTables) != c.cfg.NumLayers {
		c.ptrTables = make([]devicePtrTable, c.cfg.NumLayers)
	}
	pt := &c.ptrTables[layer]
	if pt.kDev != nil && pt.vDev != nil && pt.len == numBlocks && pt.gpu == layerDev.GPU {
		return pt.kDev, pt.vDev, c.pool.cfg.BlockSize, pt.kvType, nil
	}
	// Rebuild the pointer tables when the block count or device changes.
	kPtrs := make([]uintptr, numBlocks)
	vPtrs := make([]uintptr, numBlocks)
	var kvType tensor.DType
	for i := 0; i < numBlocks; i++ {
		bid := c.blockIDs[layer][i]
		b := c.pool.GetBlock(layer, bid)
		if b == nil {
			return nil, nil, 0, 0, fmt.Errorf("block %d not found", bid)
		}
		kT, ok := b.K.(*cuda.Tensor)
		if !ok {
			return nil, nil, 0, 0, fmt.Errorf("block K is not CUDA tensor")
		}
		vT, ok := b.V.(*cuda.Tensor)
		if !ok {
			return nil, nil, 0, 0, fmt.Errorf("block V is not CUDA tensor")
		}
		if kvType == 0 {
			kvType = kT.DType()
		}
		if kT.DType() != kvType || vT.DType() != kvType {
			return nil, nil, 0, 0, fmt.Errorf("mixed KV dtypes in blocks")
		}
		kPtrs[i] = uintptr(kT.Data().(unsafe.Pointer))
		vPtrs[i] = uintptr(vT.Data().(unsafe.Pointer))
	}
	kDev, err := cuda.AllocAndCopyPtrTable(kPtrs, layerDev.GPU)
	if err != nil {
		return nil, nil, 0, 0, err
	}
	vDev, err := cuda.AllocAndCopyPtrTable(vPtrs, layerDev.GPU)
	if err != nil {
		cuda.FreeDevicePtr(kDev)
		return nil, nil, 0, 0, err
	}
	if pt.kDev != nil {
		cuda.FreeDevicePtr(pt.kDev)
	}
	if pt.vDev != nil {
		cuda.FreeDevicePtr(pt.vDev)
	}
	pt.kDev = kDev
	pt.vDev = vDev
	pt.len = numBlocks
	pt.gpu = layerDev.GPU
	pt.kvType = kvType
	return pt.kDev, pt.vDev, c.pool.cfg.BlockSize, pt.kvType, nil
}

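// ptrTableSketch is an illustrative sketch, under the assumption that the
// layer lives on CUDA, of one plausible way a caller could size and fetch the
// device pointer tables for paged attention. The actual kernel launch that
// would consume these pointers is outside this file.
func ptrTableSketch(c *PagedKVCache, layer int) error {
	seqLen := c.SeqLen()
	numBlocks := (seqLen + c.BlockSize() - 1) / c.BlockSize()
	kTable, vTable, blockSize, kvType, err := c.LayerDevicePtrTables(layer, numBlocks)
	if err != nil {
		return err
	}
	// kTable/vTable point at device arrays of numBlocks block base pointers;
	// blockSize and kvType describe each block's layout and dtype.
	_, _, _, _ = kTable, vTable, blockSize, kvType
	return nil
}
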
// writePackedBlockU32 copies token-major K/V rows into a block's packed,
// head-major float32 layout.
func writePackedBlockU32(dstK, dstV []float32, numKVHeads, headDim, blockSize, blockOffset, writeCount int, kSrc, vSrc []float32) {
	// kSrc/vSrc are token-major: [t][kvHead][d]
	// dstK/dstV are head-major: [kvHead][t][d]
	for t := 0; t < writeCount; t++ {
		baseTok := t * (numKVHeads * headDim)
		dstTok := blockOffset + t
		for h := 0; h < numKVHeads; h++ {
			srcBase := baseTok + h*headDim
			dstBase := h*(blockSize*headDim) + dstTok*headDim
			copy(dstK[dstBase:dstBase+headDim], kSrc[srcBase:srcBase+headDim])
			copy(dstV[dstBase:dstBase+headDim], vSrc[srcBase:srcBase+headDim])
		}
	}
}

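// packedLayoutExample is an illustrative sketch of the layout conversion done
// by writePackedBlockU32, using tiny made-up dimensions (2 KV heads, headDim
// 2, blockSize 4): the source row is token-major, the packed block is
// head-major.
func packedLayoutExample() {
	const (
		numKVHeads = 2
		headDim    = 2
		blockSize  = 4
	)
	// One token of input: [h0d0, h0d1, h1d0, h1d1].
	kSrc := []float32{1, 2, 3, 4}
	vSrc := []float32{5, 6, 7, 8}
	dstK := make([]float32, numKVHeads*blockSize*headDim)
	dstV := make([]float32, numKVHeads*blockSize*headDim)
	// Write the single token at block offset 1. Afterwards head 0 occupies
	// dstK[0:8] with the token at dstK[2:4]={1,2}, and head 1 occupies
	// dstK[8:16] with the token at dstK[10:12]={3,4}.
	writePackedBlockU32(dstK, dstV, numKVHeads, headDim, blockSize, 1, 1, kSrc, vSrc)
}
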
// GetBlockForPosition returns the KV block for writing at a token position.
func (c *PagedKVCache) GetBlockForPosition(layer, tokenPos int) *KVBlock {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if layer < 0 || layer >= len(c.blockIDs) {
		return nil
	}
	blockIdx := tokenPos / c.pool.cfg.BlockSize
	if blockIdx < 0 || blockIdx >= len(c.blockIDs[layer]) {
		return nil
	}
	blockID := c.blockIDs[layer][blockIdx]
	return c.pool.GetBlock(layer, blockID)
}

// GetBlockOffset returns the offset within a block for a token position.
func (c *PagedKVCache) GetBlockOffset(tokenPos int) int {
	return tokenPos % c.pool.cfg.BlockSize
}

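// blockAddressingExample illustrates the position-to-block arithmetic used
// throughout this file: with a block size of 16, token position 37 lands in
// block index 2 at offset 5. It is a sketch, not used by the cache.
func blockAddressingExample(c *PagedKVCache) (blockIdx, offset int) {
	const tokenPos = 37
	blockIdx = tokenPos / c.BlockSize()
	offset = c.GetBlockOffset(tokenPos)
	return blockIdx, offset
}
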
// LayerBlockIDs returns a copy of up to numBlocks block IDs for a layer.
func (c *PagedKVCache) LayerBlockIDs(layer int, numBlocks int) []int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if layer < 0 || layer >= len(c.blockIDs) {
		return nil
	}
	if numBlocks <= 0 {
		return nil
	}
	if numBlocks > len(c.blockIDs[layer]) {
		numBlocks = len(c.blockIDs[layer])
	}
	out := make([]int, numBlocks)
	copy(out, c.blockIDs[layer][:numBlocks])
	return out
}

// LayerBlockPtrTables returns host-side tables of per-block device base
// pointers for a layer's K and V blocks, plus the block size and KV dtype.
func (c *PagedKVCache) LayerBlockPtrTables(layer int, numBlocks int) ([]uintptr, []uintptr, int, tensor.DType, error) {
	if numBlocks <= 0 {
		return nil, nil, 0, 0, fmt.Errorf("numBlocks must be > 0")
	}
	blockIDs := c.LayerBlockIDs(layer, numBlocks)
	if len(blockIDs) == 0 {
		return nil, nil, 0, 0, fmt.Errorf("no blocks")
	}
	kPtrs := make([]uintptr, len(blockIDs))
	vPtrs := make([]uintptr, len(blockIDs))
	var kvType tensor.DType
	for i, bid := range blockIDs {
		b := c.pool.GetBlock(layer, bid)
		if b == nil {
			return nil, nil, 0, 0, fmt.Errorf("block %d not found", bid)
		}
		kT, ok := b.K.(*cuda.Tensor)
		if !ok {
			return nil, nil, 0, 0, fmt.Errorf("block K is not CUDA tensor")
		}
		vT, ok := b.V.(*cuda.Tensor)
		if !ok {
			return nil, nil, 0, 0, fmt.Errorf("block V is not CUDA tensor")
		}
		if kvType == 0 {
			kvType = kT.DType()
		}
		if kT.DType() != kvType || vT.DType() != kvType {
			return nil, nil, 0, 0, fmt.Errorf("mixed KV dtypes in blocks")
		}
		kPtrs[i] = uintptr(kT.Data().(unsafe.Pointer))
		vPtrs[i] = uintptr(vT.Data().(unsafe.Pointer))
	}
	return kPtrs, vPtrs, c.pool.cfg.BlockSize, kvType, nil
}

// Append writes new K/V tokens into the cache for a layer.
// Implements KVCacheInterface.
func (c *PagedKVCache) Append(layer int, k, v tensor.Tensor) ([]View, int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if layer < 0 || layer >= len(c.blockIDs) {
		return nil, 0, fmt.Errorf("invalid layer %d", layer)
	}
	layerPlacement := c.LayerDevice(layer).Normalize()
	startPos := c.numComputed
	newTokens := k.Shape()[0]
	kvDim := k.Shape()[1]
	endPos := startPos + newTokens
	if c.cfg.MaxSeqLen > 0 && endPos > c.cfg.MaxSeqLen {
		return nil, 0, fmt.Errorf("KV cache overflow: need %d tokens (start %d + new %d) > MaxSeqLen %d", endPos, startPos, newTokens, c.cfg.MaxSeqLen)
	}
	if layerPlacement.Type == tensor.CUDA {
		kSrc, kIsCUDA := k.(*cuda.Tensor)
		vSrc, vIsCUDA := v.(*cuda.Tensor)
		if !kIsCUDA || !vIsCUDA {
			return nil, 0, fmt.Errorf("PagedKVCache layer %d requires CUDA tensors, got %T/%T", layer, k, v)
		}
		if layerPlacement.GPU >= 0 {
			if kSrc.GPU() != layerPlacement.GPU || vSrc.GPU() != layerPlacement.GPU {
				return nil, 0, fmt.Errorf("PagedKVCache layer %d on GPU %d, got k/v on GPU %d/%d", layer, layerPlacement.GPU, kSrc.GPU(), vSrc.GPU())
			}
		}
		// Ensure we have enough blocks to cover the write range.
		requiredBlocks := (endPos + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
		if err := c.ensureCapacityLocked(requiredBlocks); err != nil {
			return nil, 0, fmt.Errorf("ensure capacity: %w", err)
		}
		written := 0
		for written < newTokens {
			globalPos := startPos + written
			blockIdx := globalPos / c.pool.cfg.BlockSize
			blockOffset := globalPos % c.pool.cfg.BlockSize
			// Capacity is ensured above.
			block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
			if block == nil {
				return nil, 0, fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
			}
			// How many tokens can we write to this block?
			writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
			// Copy K/V to block
			srcStart := written * kvDim
			dstStart := blockOffset * kvDim
			length := writeCount * kvDim
			kDst, ok := block.K.(*cuda.Tensor)
			if !ok {
				return nil, 0, fmt.Errorf("block K is not CUDA tensor")
			}
			vDst, ok := block.V.(*cuda.Tensor)
			if !ok {
				return nil, 0, fmt.Errorf("block V is not CUDA tensor")
			}
			if kDst.DType() == tensor.Float16 && kSrc.DType() == tensor.Float32 {
				srcPtr := unsafe.Pointer(uintptr(kSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
				dstPtr := unsafe.Pointer(uintptr(kDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
				if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, kDst.GPU()); err != nil {
					return nil, 0, fmt.Errorf("cast K f32->f16 failed: %w", err)
				}
			} else {
				if err := kDst.CopyPartialFromDevice(dstStart, kSrc, srcStart, length); err != nil {
					return nil, 0, fmt.Errorf("copy K failed: %w", err)
				}
			}
			if vDst.DType() == tensor.Float16 && vSrc.DType() == tensor.Float32 {
				srcPtr := unsafe.Pointer(uintptr(vSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
				dstPtr := unsafe.Pointer(uintptr(vDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
				if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, vDst.GPU()); err != nil {
					return nil, 0, fmt.Errorf("cast V f32->f16 failed: %w", err)
				}
			} else {
				if err := vDst.CopyPartialFromDevice(dstStart, vSrc, srcStart, length); err != nil {
					return nil, 0, fmt.Errorf("copy V failed: %w", err)
				}
			}
			written += writeCount
		}
		if endPos > c.numWritten {
			c.numWritten = endPos
		}
		return c.viewsLockedAt(layer, endPos), startPos, nil
	}
	kSrc, kIsCPU := k.(*cpu.Tensor)
	vSrc, vIsCPU := v.(*cpu.Tensor)
	if !kIsCPU || !vIsCPU {
		return nil, 0, fmt.Errorf("PagedKVCache layer %d requires CPU tensors, got %T/%T", layer, k, v)
	}
	// Ensure we have enough blocks to cover the write range.
	requiredBlocks := (endPos + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
	if err := c.ensureCapacityLocked(requiredBlocks); err != nil {
		return nil, 0, fmt.Errorf("ensure capacity: %w", err)
	}
	var kF32, vF32 []float32
	if kSrc.DType() == tensor.Float32 {
		kF32 = kSrc.DataFloat32()
	}
	if vSrc.DType() == tensor.Float32 {
		vF32 = vSrc.DataFloat32()
	}
	written := 0
	for written < newTokens {
		globalPos := startPos + written
		blockIdx := globalPos / c.pool.cfg.BlockSize
		blockOffset := globalPos % c.pool.cfg.BlockSize
		// Capacity is ensured above.
		block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
		if block == nil {
			return nil, 0, fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
		}
		// How many tokens can we write to this block?
		writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
		// Copy K/V to block
		srcStart := written * kvDim
		dstStart := blockOffset * kvDim
		length := writeCount * kvDim
		kDst, ok := block.K.(*cpu.Tensor)
		if !ok {
			return nil, 0, fmt.Errorf("block K is not CPU tensor")
		}
		vDst, ok := block.V.(*cpu.Tensor)
		if !ok {
			return nil, 0, fmt.Errorf("block V is not CPU tensor")
		}
		switch kDst.DType() {
		case tensor.Float16:
			kOut := kDst.DataUint16()[dstStart : dstStart+length]
			switch kSrc.DType() {
			case tensor.Float32:
				kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					kOut[i] = float32ToFloat16Bits(kIn[i])
				}
			case tensor.Float16:
				copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return nil, 0, fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
			}
		case tensor.BFloat16:
			kOut := kDst.DataUint16()[dstStart : dstStart+length]
			switch kSrc.DType() {
			case tensor.Float32:
				kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					kOut[i] = float32ToBFloat16Bits(kIn[i])
				}
			case tensor.BFloat16:
				copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return nil, 0, fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
			}
		default:
			return nil, 0, fmt.Errorf("unsupported CPU K dst dtype: %v", kDst.DType())
		}
		switch vDst.DType() {
		case tensor.Float16:
			vOut := vDst.DataUint16()[dstStart : dstStart+length]
			switch vSrc.DType() {
			case tensor.Float32:
				vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					vOut[i] = float32ToFloat16Bits(vIn[i])
				}
			case tensor.Float16:
				copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return nil, 0, fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
			}
		case tensor.BFloat16:
			vOut := vDst.DataUint16()[dstStart : dstStart+length]
			switch vSrc.DType() {
			case tensor.Float32:
				vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					vOut[i] = float32ToBFloat16Bits(vIn[i])
				}
			case tensor.BFloat16:
				copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return nil, 0, fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
			}
		default:
			return nil, 0, fmt.Errorf("unsupported CPU V dst dtype: %v", vDst.DType())
		}
		// Populate packed CPU layout when available and source is float32.
		if block.pk != nil && block.pv != nil && kF32 != nil && vF32 != nil {
			srcStartTok := written * kvDim
			srcEndTok := srcStartTok + length
			writePackedBlockU32(block.pk, block.pv, c.cfg.NumKVHeads, c.cfg.HeadDim, c.pool.cfg.BlockSize, blockOffset, writeCount, kF32[srcStartTok:srcEndTok], vF32[srcStartTok:srcEndTok])
		}
		written += writeCount
	}
	if endPos > c.numWritten {
		c.numWritten = endPos
	}
	return c.viewsLockedAt(layer, endPos), startPos, nil
}

// Views returns the live KV block views for a layer.
func (c *PagedKVCache) Views(layer int) []View {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.viewsLockedAt(layer, c.numComputed)
}

// ViewsPacked returns packed (head-major) CPU views covering the written
// tokens of a CPU-resident layer. Blocks without packed buffers are skipped;
// non-CPU layers return nil.
func (c *PagedKVCache) ViewsPacked(layer int) []PackedView {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.LayerDevice(layer).Type != tensor.CPU {
		return nil
	}
	computed := c.numWritten
	if computed <= 0 {
		return nil
	}
	if layer < 0 || layer >= len(c.blockIDs) {
		return nil
	}
	numBlocks := (computed + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
	views := make([]PackedView, 0, numBlocks)
	for i := 0; i < numBlocks && i < len(c.blockIDs[layer]); i++ {
		block := c.pool.GetBlock(layer, c.blockIDs[layer][i])
		if block == nil || block.pk == nil || block.pv == nil {
			continue
		}
		start := i * c.pool.cfg.BlockSize
		length := c.pool.cfg.BlockSize
		if start+length > computed {
			length = computed - start
		}
		if length <= 0 {
			continue
		}
		views = append(views, PackedView{
			K:          block.pk,
			V:          block.pv,
			Start:      start,
			Length:     length,
			BlockSize:  c.pool.cfg.BlockSize,
			HeadDim:    c.cfg.HeadDim,
			NumKVHeads: c.cfg.NumKVHeads,
		})
	}
	return views
}

// viewsLocked returns live views for the currently committed KV length.
// Kept for internal call sites.
func (c *PagedKVCache) viewsLocked(layer int) []View {
	return c.viewsLockedAt(layer, c.numComputed)
}

// viewsLockedAt returns block views covering the first `computed` tokens of a
// layer. The caller must hold c.mu.
func (c *PagedKVCache) viewsLockedAt(layer int, computed int) []View {
	if layer < 0 || layer >= len(c.blockIDs) {
		return nil
	}
	if computed <= 0 {
		return nil
	}
	numBlocks := (computed + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
	views := make([]View, 0, numBlocks)
	layerDev := c.pool.LayerDevice(layer).Normalize()
	for i := 0; i < numBlocks && i < len(c.blockIDs[layer]); i++ {
		block := c.pool.GetBlock(layer, c.blockIDs[layer][i])
		if block == nil {
			continue
		}
		start := i * c.pool.cfg.BlockSize
		length := c.pool.cfg.BlockSize
		if start+length > computed {
			length = computed - start
		}
		if length <= 0 {
			continue
		}
		views = append(views, View{
			K:      block.K,
			V:      block.V,
			Start:  start,
			Length: length,
			Device: layerDev.Type,
			GPU:    layerDev.GPU,
		})
	}
	return views
}

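// totalViewLength is an illustrative helper-style sketch, not used by the
// cache: summing View.Length across a layer's views yields the number of
// committed tokens visible for that layer.
func totalViewLength(c *PagedKVCache, layer int) int {
	total := 0
	for _, view := range c.Views(layer) {
		total += view.Length
	}
	return total
}
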
// AppendKV writes K/V values for new tokens into the cache.
// This is the main write path for prefill and decode.
func (c *PagedKVCache) AppendKV(layer int, k, v tensor.Tensor, startPos int) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if layer < 0 || layer >= len(c.blockIDs) {
		return fmt.Errorf("invalid layer %d", layer)
	}
	layerPlacement := c.LayerDevice(layer).Normalize()
	newTokens := k.Shape()[0]
	kvDim := k.Shape()[1]
	if c.cfg.MaxSeqLen > 0 && startPos+newTokens > c.cfg.MaxSeqLen {
		return fmt.Errorf("KV cache overflow: need %d tokens (start %d + new %d) > MaxSeqLen %d", startPos+newTokens, startPos, newTokens, c.cfg.MaxSeqLen)
	}
	if layerPlacement.Type == tensor.CUDA {
		kSrc, kIsCUDA := k.(*cuda.Tensor)
		vSrc, vIsCUDA := v.(*cuda.Tensor)
		if !kIsCUDA || !vIsCUDA {
			return fmt.Errorf("PagedKVCache layer %d requires CUDA tensors, got %T/%T", layer, k, v)
		}
		written := 0
		for written < newTokens {
			globalPos := startPos + written
			blockIdx := globalPos / c.pool.cfg.BlockSize
			blockOffset := globalPos % c.pool.cfg.BlockSize
			if blockIdx >= len(c.blockIDs[layer]) {
				return fmt.Errorf("block index %d out of range", blockIdx)
			}
			block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
			if block == nil {
				return fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
			}
			// How many tokens can we write to this block?
			writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
			// Copy K/V to block
			srcStart := written * kvDim
			dstStart := blockOffset * kvDim
			length := writeCount * kvDim
			kDst, ok := block.K.(*cuda.Tensor)
			if !ok {
				return fmt.Errorf("block K is not CUDA tensor")
			}
			vDst, ok := block.V.(*cuda.Tensor)
			if !ok {
				return fmt.Errorf("block V is not CUDA tensor")
			}
			if kDst.DType() == tensor.Float16 && kSrc.DType() == tensor.Float32 {
				srcPtr := unsafe.Pointer(uintptr(kSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
				dstPtr := unsafe.Pointer(uintptr(kDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
				if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, kDst.GPU()); err != nil {
					return fmt.Errorf("cast K f32->f16 failed: %w", err)
				}
			} else {
				if err := kDst.CopyPartialFromDevice(dstStart, kSrc, srcStart, length); err != nil {
					return fmt.Errorf("copy K failed: %w", err)
				}
			}
			if vDst.DType() == tensor.Float16 && vSrc.DType() == tensor.Float32 {
				srcPtr := unsafe.Pointer(uintptr(vSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
				dstPtr := unsafe.Pointer(uintptr(vDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
				if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, vDst.GPU()); err != nil {
					return fmt.Errorf("cast V f32->f16 failed: %w", err)
				}
			} else {
				if err := vDst.CopyPartialFromDevice(dstStart, vSrc, srcStart, length); err != nil {
					return fmt.Errorf("copy V failed: %w", err)
				}
			}
			written += writeCount
		}
		return nil
	}
	kSrc, kIsCPU := k.(*cpu.Tensor)
	vSrc, vIsCPU := v.(*cpu.Tensor)
	if !kIsCPU || !vIsCPU {
		return fmt.Errorf("PagedKVCache layer %d requires CPU tensors, got %T/%T", layer, k, v)
	}
	written := 0
	for written < newTokens {
		globalPos := startPos + written
		blockIdx := globalPos / c.pool.cfg.BlockSize
		blockOffset := globalPos % c.pool.cfg.BlockSize
		if blockIdx >= len(c.blockIDs[layer]) {
			return fmt.Errorf("block index %d out of range", blockIdx)
		}
		block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
		if block == nil {
			return fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
		}
		// How many tokens can we write to this block?
		writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
		// Copy K/V to block
		srcStart := written * kvDim
		dstStart := blockOffset * kvDim
		length := writeCount * kvDim
		kDst, ok := block.K.(*cpu.Tensor)
		if !ok {
			return fmt.Errorf("block K is not CPU tensor")
		}
		vDst, ok := block.V.(*cpu.Tensor)
		if !ok {
			return fmt.Errorf("block V is not CPU tensor")
		}
		switch kDst.DType() {
		case tensor.Float16:
			kOut := kDst.DataUint16()[dstStart : dstStart+length]
			switch kSrc.DType() {
			case tensor.Float32:
				kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					kOut[i] = float32ToFloat16Bits(kIn[i])
				}
			case tensor.Float16:
				copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
			}
		case tensor.BFloat16:
			kOut := kDst.DataUint16()[dstStart : dstStart+length]
			switch kSrc.DType() {
			case tensor.Float32:
				kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					kOut[i] = float32ToBFloat16Bits(kIn[i])
				}
			case tensor.BFloat16:
				copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
			}
		default:
			return fmt.Errorf("unsupported CPU K dst dtype: %v", kDst.DType())
		}
		switch vDst.DType() {
		case tensor.Float16:
			vOut := vDst.DataUint16()[dstStart : dstStart+length]
			switch vSrc.DType() {
			case tensor.Float32:
				vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					vOut[i] = float32ToFloat16Bits(vIn[i])
				}
			case tensor.Float16:
				copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
			}
		case tensor.BFloat16:
			vOut := vDst.DataUint16()[dstStart : dstStart+length]
			switch vSrc.DType() {
			case tensor.Float32:
				vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
				for i := 0; i < length; i++ {
					vOut[i] = float32ToBFloat16Bits(vIn[i])
				}
			case tensor.BFloat16:
				copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
			default:
				return fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
			}
		default:
			return fmt.Errorf("unsupported CPU V dst dtype: %v", vDst.DType())
		}
		written += writeCount
	}
	return nil
}

// Commit marks tokens as computed and caches completed blocks.
func (c *PagedKVCache) Commit(newTokens int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	oldComputed := c.numComputed
	c.numComputed += newTokens
	if c.numWritten < c.numComputed {
		c.numWritten = c.numComputed
	}
	// Cache newly completed blocks (only those we have hashes for).
	oldBlocks := oldComputed / c.pool.cfg.BlockSize
	newBlocks := c.numComputed / c.pool.cfg.BlockSize
	maxHashBlocks := len(c.blockHashes)
	if newBlocks > maxHashBlocks {
		newBlocks = maxHashBlocks
	}
	if newBlocks > oldBlocks {
		for layer := 0; layer < c.cfg.NumLayers; layer++ {
			blockIDs := c.blockIDs[layer][oldBlocks:newBlocks]
			hashes := c.blockHashes[oldBlocks:newBlocks]
			c.pool.CacheBlocks(layer, blockIDs, hashes)
		}
	}
}

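// decodeStepSketch is an illustrative, non-authoritative sketch of a single
// decode step: record the newly sampled token, write its K/V for every layer,
// then commit one position. k1/v1 are assumed to hold one row per layer on
// the same device as the corresponding layer.
func decodeStepSketch(c *PagedKVCache, tokenID int, k1, v1 []tensor.Tensor) error {
	c.AppendToken(tokenID)
	for layer := 0; layer < len(k1); layer++ {
		if _, _, err := c.Append(layer, k1[layer], v1[layer]); err != nil {
			return err
		}
	}
	c.Commit(1)
	return nil
}
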
// SeqLen returns the number of computed tokens.
func (c *PagedKVCache) SeqLen() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.numComputed
}

// ContiguousKV would return a contiguous view of K/V for attention by
// gathering from blocks into a single buffer.
func (c *PagedKVCache) ContiguousKV(layer, kvLen, kvDim int) (tensor.Tensor, tensor.Tensor, bool, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	// NOTE: PagedKVCache does not currently provide a contiguous K/V view.
	// Building one by allocating [kvLen, kvDim] tensors per layer causes large
	// transient CUDA allocations and can trigger GPU OOM under load.
	// Callers should fall back to Views()+concatKVOnDevice.
	return nil, nil, false, nil
}

// Free releases all blocks back to the pool.
func (c *PagedKVCache) Free() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.clearPtrTablesLocked()
	for layer := 0; layer < len(c.blockIDs); layer++ {
		c.pool.FreeBlocks(layer, c.blockIDs[layer])
	}
	c.blockIDs = nil
	c.ownedBlocks = nil
	c.tokenIDs = nil
	c.blockHashes = nil
}

// RequestID returns the request ID.
func (c *PagedKVCache) RequestID() string {
	return c.requestID
}

// NumTokens returns total tokens in sequence.
func (c *PagedKVCache) NumTokens() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return len(c.tokenIDs)
}

// BlockSize returns the block size.
func (c *PagedKVCache) BlockSize() int {
	return c.pool.cfg.BlockSize
}

// Truncate rewinds the cache to a specific sequence length.
func (c *PagedKVCache) Truncate(seqLen int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if seqLen < 0 {
		seqLen = 0
	}
	if seqLen >= c.numComputed {
		return
	}
	c.numComputed = seqLen
	if c.numWritten > c.numComputed {
		c.numWritten = c.numComputed
	}
}

// LayerDevice returns the device placement for a layer.
func (c *PagedKVCache) LayerDevice(layer int) tensor.DevicePlacement {
	if c.pool != nil {
		return c.pool.LayerDevice(layer)
	}
	return tensor.DevicePlacement{Type: c.cfg.Device, GPU: c.cfg.GPU}
}

// MaxSeqLen returns the maximum sequence length.
func (c *PagedKVCache) MaxSeqLen() int {
	return c.cfg.MaxSeqLen
}

// IsOnGPU reports whether any cache layer is placed on a CUDA device.
func (c *PagedKVCache) IsOnGPU() bool {
	if c == nil {
		return false
	}
	if c.pool != nil {
		for i := 0; i < c.cfg.NumLayers; i++ {
			if c.pool.LayerDevice(i).Type == tensor.CUDA {
				return true
			}
		}
		return false
	}
	return c.cfg.Device == tensor.CUDA
}