// Package kvcache implements a paged KV cache with global prefix caching.
package kvcache

import (
	"crypto/sha256"
	"encoding/binary"
	"fmt"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/tensor"
)
// BlockHash is the hash of a block's token content for prefix caching.
type BlockHash [32]byte

// KVBlock represents a single block of KV cache memory.
type KVBlock struct {
	ID int
	K  tensor.Tensor // [blockSize, kvDim]
	V  tensor.Tensor // [blockSize, kvDim]
	// pk/pv are CPU-side float32 buffers of blockSize*kvDim elements,
	// allocated only for CPU-resident layers (see ensureAllocatedLocked).
	pk       []float32
	pv       []float32
	Hash     *BlockHash // nil if not cached/committed
	RefCount int        // number of requests using this block
	Layer    int        // which layer this block belongs to
	GPU      int        // CUDA device ID, or -1 for CPU-resident blocks
}
// BlockPoolConfig configures the block pool.
type BlockPoolConfig struct {
	NumLayers  int
	NumKVHeads int
	HeadDim    int
	BlockSize  int
	NumBlocks  int // total blocks per layer
	Device     tensor.DeviceType
	GPU        int
	LayerGPUs  []int
	// LayerPlacements optionally overrides per-layer device placement.
	// If provided, it must have length NumLayers and may mix CPU/CUDA layers.
	// When set, Device/GPU/LayerGPUs are treated as defaults for layers that
	// don't specify a valid placement.
	LayerPlacements []tensor.DevicePlacement
	Preallocate     bool
}
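
// A mixed-placement sketch: the first two layers live on GPU 0 and the rest on
// the CPU. The split and layer count are arbitrary, purely for illustration:
//
//	placements := []tensor.DevicePlacement{
//		{Type: tensor.CUDA, GPU: 0},
//		{Type: tensor.CUDA, GPU: 0},
//		{Type: tensor.CPU, GPU: -1},
//		{Type: tensor.CPU, GPU: -1},
//	}
//	cfg := BlockPoolConfig{NumLayers: 4 /* ... */, LayerPlacements: placements}
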
// BlockPool manages a pool of KV cache blocks with prefix caching.
type BlockPool struct {
	cfg BlockPoolConfig
	mu  sync.RWMutex

	// blocks[layer][blockID] = *KVBlock
	blocks [][]*KVBlock
	// Free block IDs per layer (LIFO for locality).
	freeBlocks [][]int
	// Hash → block mapping for prefix caching (per layer):
	// hashToBlock[layer][hash] = blockID
	hashToBlock []map[BlockHash]int
	// Stats
	stats PrefixCacheStats

	// CUDA preallocation: one contiguous buffer per layer for K and V.
	// When set, ensureAllocatedLocked creates block views into these buffers.
	kvDim           int
	kBase           []*cuda.Tensor
	vBase           []*cuda.Tensor
	layerGPUs       []int
	layerPlacements []tensor.DevicePlacement
}
// PrefixCacheStats tracks cache performance.
type PrefixCacheStats struct {
	Hits        int64
	Misses      int64
	Evictions   int64
	Allocations int64
}
// NewBlockPool creates a new block pool.
func NewBlockPool(cfg BlockPoolConfig) (*BlockPool, error) {
	if cfg.BlockSize <= 0 {
		cfg.BlockSize = 16
	}
	if cfg.NumBlocks <= 0 {
		cfg.NumBlocks = 1024
	}
	bp := &BlockPool{
		cfg:         cfg,
		blocks:      make([][]*KVBlock, cfg.NumLayers),
		freeBlocks:  make([][]int, cfg.NumLayers),
		hashToBlock: make([]map[BlockHash]int, cfg.NumLayers),
		kvDim:       cfg.NumKVHeads * cfg.HeadDim,
	}
	if len(cfg.LayerPlacements) != 0 {
		if len(cfg.LayerPlacements) != cfg.NumLayers {
			return nil, fmt.Errorf("LayerPlacements length %d != NumLayers %d", len(cfg.LayerPlacements), cfg.NumLayers)
		}
		bp.layerPlacements = make([]tensor.DevicePlacement, cfg.NumLayers)
		for i := range cfg.LayerPlacements {
			bp.layerPlacements[i] = cfg.LayerPlacements[i].Normalize()
		}
	}
	if cfg.Device == tensor.CUDA || bp.layerPlacements != nil {
		if len(cfg.LayerGPUs) != 0 && len(cfg.LayerGPUs) != cfg.NumLayers {
			return nil, fmt.Errorf("LayerGPUs length %d != NumLayers %d", len(cfg.LayerGPUs), cfg.NumLayers)
		}
		if len(cfg.LayerGPUs) == cfg.NumLayers {
			bp.layerGPUs = append([]int(nil), cfg.LayerGPUs...)
		} else {
			bp.layerGPUs = make([]int, cfg.NumLayers)
			for i := range bp.layerGPUs {
				bp.layerGPUs[i] = cfg.GPU
			}
		}
	}
	// Initialize block metadata. K/V tensors are allocated lazily on first use.
	for layer := 0; layer < cfg.NumLayers; layer++ {
		bp.blocks[layer] = make([]*KVBlock, cfg.NumBlocks)
		bp.freeBlocks[layer] = make([]int, 0, cfg.NumBlocks)
		bp.hashToBlock[layer] = make(map[BlockHash]int)
		p := bp.LayerDevice(layer)
		gpu := p.GPU
		if p.Type != tensor.CUDA {
			gpu = -1
		}
		for b := 0; b < cfg.NumBlocks; b++ {
			bp.blocks[layer][b] = &KVBlock{ID: b, Layer: layer, GPU: gpu}
			bp.freeBlocks[layer] = append(bp.freeBlocks[layer], b)
		}
	}
	if cfg.Preallocate && cuda.Available() {
		if err := bp.preallocateCUDA(); err != nil {
			bp.Free()
			return nil, err
		}
	}
	return bp, nil
}
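
// Construction sketch (illustrative only; the model dimensions below are
// hypothetical, not taken from this package):
//
//	pool, err := NewBlockPool(BlockPoolConfig{
//		NumLayers:  32,
//		NumKVHeads: 8,
//		HeadDim:    128,
//		BlockSize:  16,   // tokens per block
//		NumBlocks:  1024, // blocks per layer
//		Device:     tensor.CPU,
//	})
//	if err != nil {
//		// handle error
//	}
//	defer pool.Free()
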
// preallocateCUDA allocates one contiguous [NumBlocks*BlockSize, kvDim]
// Float16 buffer per CUDA-resident layer for K and one for V; individual
// blocks are later carved out of these buffers as views.
func (bp *BlockPool) preallocateCUDA() error {
	if bp.kvDim <= 0 {
		return fmt.Errorf("invalid kvDim %d", bp.kvDim)
	}
	if bp.cfg.BlockSize <= 0 || bp.cfg.NumBlocks <= 0 {
		return fmt.Errorf("invalid block pool sizes: blockSize=%d numBlocks=%d", bp.cfg.BlockSize, bp.cfg.NumBlocks)
	}
	totalTokens := bp.cfg.NumBlocks * bp.cfg.BlockSize
	shape := tensor.Shape{totalTokens, bp.kvDim}
	bp.kBase = make([]*cuda.Tensor, bp.cfg.NumLayers)
	bp.vBase = make([]*cuda.Tensor, bp.cfg.NumLayers)
	for layer := 0; layer < bp.cfg.NumLayers; layer++ {
		p := bp.LayerDevice(layer)
		if p.Type != tensor.CUDA {
			continue
		}
		gpu := p.GPU
		k, err := cuda.NewTensor(shape, tensor.Float16, gpu)
		if err != nil {
			return fmt.Errorf("preallocate K layer %d: %w", layer, err)
		}
		v, err := cuda.NewTensor(shape, tensor.Float16, gpu)
		if err != nil {
			k.Free()
			return fmt.Errorf("preallocate V layer %d: %w", layer, err)
		}
		bp.kBase[layer] = k
		bp.vBase[layer] = v
	}
	return nil
}
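
// Rough footprint of the preallocated buffers, as a worked example: with the
// default BlockSize = 16 and NumBlocks = 1024 and a hypothetical kvDim of
// 8*128 = 1024, each per-layer buffer holds 16384 tokens × 1024 values ×
// 2 bytes (Float16) ≈ 32 MiB, so K and V together take ≈ 64 MiB per
// CUDA-resident layer.
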
// ensureAllocatedLocked returns the block, creating its K/V tensors on first
// use. For CUDA layers with preallocated base buffers, K/V become views into
// the contiguous per-layer buffer; otherwise they are allocated individually.
// Callers must hold bp.mu.
func (bp *BlockPool) ensureAllocatedLocked(layer, blockID int) (*KVBlock, error) {
	if layer < 0 || layer >= len(bp.blocks) {
		return nil, fmt.Errorf("invalid layer %d", layer)
	}
	if blockID < 0 || blockID >= len(bp.blocks[layer]) {
		return nil, fmt.Errorf("invalid blockID %d", blockID)
	}
	block := bp.blocks[layer][blockID]
	if block == nil {
		p := bp.LayerDevice(layer)
		gpu := p.GPU
		if p.Type != tensor.CUDA {
			gpu = -1
		}
		block = &KVBlock{ID: blockID, Layer: layer, GPU: gpu}
		bp.blocks[layer][blockID] = block
	}
	if block.K != nil && block.V != nil {
		return block, nil
	}
	kvDim := bp.kvDim
	p := bp.LayerDevice(layer)
	if p.Type == tensor.CUDA && bp.kBase != nil && bp.vBase != nil {
		if layer < len(bp.kBase) && layer < len(bp.vBase) && bp.kBase[layer] != nil && bp.vBase[layer] != nil {
			// Block b occupies rows [b*BlockSize, (b+1)*BlockSize) of the
			// contiguous [NumBlocks*BlockSize, kvDim] base buffer.
			elemSize := uintptr(tensor.Float16.Size())
			blockElems := uintptr(bp.cfg.BlockSize * kvDim)
			offsetBytes := uintptr(blockID) * blockElems * elemSize
			k, err := bp.kBase[layer].ViewAt(tensor.Shape{bp.cfg.BlockSize, kvDim}, offsetBytes)
			if err != nil {
				return nil, fmt.Errorf("create K view: %w", err)
			}
			v, err := bp.vBase[layer].ViewAt(tensor.Shape{bp.cfg.BlockSize, kvDim}, offsetBytes)
			if err != nil {
				return nil, fmt.Errorf("create V view: %w", err)
			}
			block.K = k
			block.V = v
		}
	}
	if block.K == nil || block.V == nil {
		k, v, err := allocateBlock(bp.cfg.BlockSize, kvDim, p.Type, p.GPU)
		if err != nil {
			return nil, err
		}
		block.K = k
		block.V = v
	}
	if p.Type == tensor.CPU {
		// Float32 buffers for CPU-resident layers (NumKVHeads*BlockSize*HeadDim
		// = blockSize*kvDim elements each).
		bufSize := bp.cfg.NumKVHeads * bp.cfg.BlockSize * bp.cfg.HeadDim
		block.pk = make([]float32, bufSize)
		block.pv = make([]float32, bufSize)
	}
	return block, nil
}
func allocateBlock(blockSize, kvDim int, device tensor.DeviceType, gpu int) (tensor.Tensor, tensor.Tensor, error) {
	shape := tensor.Shape{blockSize, kvDim}
	if device == tensor.CUDA && cuda.Available() {
		k, err := cuda.NewTensor(shape, tensor.Float16, gpu)
		if err != nil {
			return nil, nil, err
		}
		v, err := cuda.NewTensor(shape, tensor.Float16, gpu)
		if err != nil {
			k.Free() // don't leak K if V allocation fails
			return nil, nil, err
		}
		return k, v, nil
	}
	k, err := cpu.NewTensorU16(shape, tensor.Float16, nil)
	if err != nil {
		return nil, nil, err
	}
	v, err := cpu.NewTensorU16(shape, tensor.Float16, nil)
	if err != nil {
		return nil, nil, err
	}
	return k, v, nil
}
// ComputeBlockHash computes the hash for a sequence of token IDs, chained to
// the parent block's hash. This is used for prefix caching: the same tokens
// following the same prefix always produce the same hash.
func ComputeBlockHash(tokens []int, parentHash *BlockHash) BlockHash {
	h := sha256.New()
	if parentHash != nil {
		h.Write(parentHash[:])
	}
	buf := make([]byte, 4)
	for _, t := range tokens {
		binary.LittleEndian.PutUint32(buf, uint32(t))
		h.Write(buf)
	}
	var hash BlockHash
	copy(hash[:], h.Sum(nil))
	return hash
}
// ComputeBlockHashes computes hashes for all blocks in a token sequence.
func ComputeBlockHashes(tokens []int, blockSize int) []BlockHash {
	numBlocks := (len(tokens) + blockSize - 1) / blockSize
	hashes := make([]BlockHash, numBlocks)
	var parentHash *BlockHash
	for i := 0; i < numBlocks; i++ {
		start := i * blockSize
		end := start + blockSize
		if end > len(tokens) {
			end = len(tokens)
		}
		hashes[i] = ComputeBlockHash(tokens[start:end], parentHash)
		parentHash = &hashes[i]
	}
	return hashes
}
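
// Because each block hash folds in its parent's hash, two sequences share
// block hashes exactly for their common block-aligned prefix. A small sketch
// (token values and blockSize are made up for illustration):
//
//	a := ComputeBlockHashes([]int{1, 2, 3, 4, 5, 6, 7, 8}, 4)
//	b := ComputeBlockHashes([]int{1, 2, 3, 4, 9, 9, 9, 9}, 4)
//	// a[0] == b[0]  (identical first block)
//	// a[1] != b[1]  (contents diverge in the second block)
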
// FindCachedBlocks finds the longest prefix of blocks that are already cached.
// Returns block IDs and the number of tokens covered.
func (bp *BlockPool) FindCachedBlocks(layer int, hashes []BlockHash) ([]int, int) {
	// Takes the write lock because the hit/miss counters are updated below.
	bp.mu.Lock()
	defer bp.mu.Unlock()
	if layer < 0 || layer >= len(bp.hashToBlock) {
		return nil, 0
	}
	cached := make([]int, 0, len(hashes))
	for _, hash := range hashes {
		if blockID, ok := bp.hashToBlock[layer][hash]; ok {
			cached = append(cached, blockID)
			bp.stats.Hits++
		} else {
			bp.stats.Misses++
			break // prefix caching requires contiguous hits
		}
	}
	return cached, len(cached) * bp.cfg.BlockSize
}
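
// Typical prefix-reuse flow for one layer, sketched with the caller's own
// variables (pool, layer, promptTokens and cfg are illustrative, not defined
// in this file):
//
//	hashes := ComputeBlockHashes(promptTokens, cfg.BlockSize)
//	reused, covered := pool.FindCachedBlocks(layer, hashes)
//	pool.TouchBlocks(layer, reused) // pin the reused prefix
//	// Allocate fresh blocks only for the uncovered tail.
//	fresh, err := pool.AllocateBlocks(layer, len(hashes)-len(reused))
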
// AllocateBlocks allocates blocks for a request and increments their ref
// counts. If the free list runs out, it evicts cached blocks whose refcount
// is zero (in map-iteration order, not LRU).
func (bp *BlockPool) AllocateBlocks(layer int, count int) ([]int, error) {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	if layer < 0 || layer >= len(bp.freeBlocks) {
		return nil, fmt.Errorf("invalid layer %d", layer)
	}
	allocated := make([]int, 0, count)
	// First try free blocks.
	for len(allocated) < count && len(bp.freeBlocks[layer]) > 0 {
		n := len(bp.freeBlocks[layer])
		blockID := bp.freeBlocks[layer][n-1]
		bp.freeBlocks[layer] = bp.freeBlocks[layer][:n-1]
		if _, err := bp.ensureAllocatedLocked(layer, blockID); err != nil {
			// Return blockID to the free list and fail.
			bp.freeBlocks[layer] = append(bp.freeBlocks[layer], blockID)
			return allocated, fmt.Errorf("allocate block %d: %w", blockID, err)
		}
		allocated = append(allocated, blockID)
		bp.stats.Allocations++
	}
	// If more blocks are still needed, evict cached blocks with refcount == 0.
	if len(allocated) < count {
		for hash, blockID := range bp.hashToBlock[layer] {
			block := bp.blocks[layer][blockID]
			if block.RefCount == 0 {
				// Evict.
				delete(bp.hashToBlock[layer], hash)
				block.Hash = nil
				if _, err := bp.ensureAllocatedLocked(layer, blockID); err != nil {
					return allocated, fmt.Errorf("allocate evicted block %d: %w", blockID, err)
				}
				allocated = append(allocated, blockID)
				bp.stats.Evictions++
				bp.stats.Allocations++
				if len(allocated) >= count {
					break
				}
			}
		}
	}
	if len(allocated) < count {
		return allocated, fmt.Errorf("not enough blocks: need %d, got %d", count, len(allocated))
	}
	// Increment ref counts.
	for _, blockID := range allocated {
		bp.blocks[layer][blockID].RefCount++
	}
	return allocated, nil
}
// TouchBlocks increments ref counts for cached blocks being reused.
func (bp *BlockPool) TouchBlocks(layer int, blockIDs []int) {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	if layer < 0 || layer >= len(bp.blocks) {
		return
	}
	for _, blockID := range blockIDs {
		if blockID >= 0 && blockID < len(bp.blocks[layer]) {
			bp.blocks[layer][blockID].RefCount++
		}
	}
}
// CacheBlocks registers blocks in the hash cache after they're computed.
func (bp *BlockPool) CacheBlocks(layer int, blockIDs []int, hashes []BlockHash) {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	for i, blockID := range blockIDs {
		if i >= len(hashes) {
			break
		}
		block, err := bp.ensureAllocatedLocked(layer, blockID)
		if err != nil {
			continue
		}
		hash := hashes[i]
		block.Hash = &hash
		bp.hashToBlock[layer][hash] = blockID
	}
}
// FreeBlocks decrements ref counts; blocks with refcount 0 become eviction candidates.
func (bp *BlockPool) FreeBlocks(layer int, blockIDs []int) {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	if layer < 0 || layer >= len(bp.blocks) {
		return
	}
	for _, blockID := range blockIDs {
		if blockID >= 0 && blockID < len(bp.blocks[layer]) {
			block := bp.blocks[layer][blockID]
			if block.RefCount > 0 {
				block.RefCount--
			}
			// Cached blocks (non-nil Hash) are not returned to the free list;
			// they stay available as prefix-cache hits until evicted. Only
			// uncached blocks are truly freed here.
			if block.RefCount == 0 && block.Hash == nil {
				bp.freeBlocks[layer] = append(bp.freeBlocks[layer], blockID)
			}
		}
	}
}
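
// End-to-end lifecycle for freshly computed blocks, sketched for one layer
// (variable names are illustrative; error handling omitted):
//
//	ids, _ := pool.AllocateBlocks(layer, len(hashes)) // refcount -> 1
//	// ... fill each block's K/V with the computed keys/values ...
//	pool.CacheBlocks(layer, ids, hashes) // publish for future prefix hits
//	// When the request finishes:
//	pool.FreeBlocks(layer, ids) // refcount -> 0; cached blocks stay evictable
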
// GetBlock returns a block by ID, allocating its K/V tensors if needed.
// It returns nil if layer or blockID is out of range or allocation fails.
func (bp *BlockPool) GetBlock(layer, blockID int) *KVBlock {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	block, err := bp.ensureAllocatedLocked(layer, blockID)
	if err != nil {
		return nil
	}
	return block
}
// LayerDevice returns the device placement for a layer, falling back to the
// pool-wide Device/GPU configuration when no per-layer placement is set.
func (bp *BlockPool) LayerDevice(layer int) tensor.DevicePlacement {
	if bp == nil {
		return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	}
	if bp.layerPlacements != nil && layer >= 0 && layer < len(bp.layerPlacements) {
		p := bp.layerPlacements[layer].Normalize()
		// Defensive: normalize invalid CUDA GPU IDs.
		if p.Type == tensor.CUDA && p.GPU < 0 {
			p.GPU = 0
		}
		return p
	}
	if bp.cfg.Device != tensor.CUDA {
		return tensor.DevicePlacement{Type: bp.cfg.Device, GPU: -1}
	}
	gpu := bp.cfg.GPU
	if bp.layerGPUs != nil && layer >= 0 && layer < len(bp.layerGPUs) {
		gpu = bp.layerGPUs[layer]
	}
	return tensor.DevicePlacement{Type: tensor.CUDA, GPU: gpu}
}
// Stats returns current cache statistics.
func (bp *BlockPool) Stats() PrefixCacheStats {
	bp.mu.RLock()
	defer bp.mu.RUnlock()
	return bp.stats
}
// NumFreeBlocks returns the free block count for a layer.
func (bp *BlockPool) NumFreeBlocks(layer int) int {
	bp.mu.RLock()
	defer bp.mu.RUnlock()
	if layer < 0 || layer >= len(bp.freeBlocks) {
		return 0
	}
	return len(bp.freeBlocks[layer])
}
// Usage returns cache utilization (0.0 to 1.0), averaged across layers.
func (bp *BlockPool) Usage() float64 {
	bp.mu.RLock()
	defer bp.mu.RUnlock()
	// Guard against an empty or already-freed pool to avoid dividing by zero.
	if bp.cfg.NumBlocks == 0 || len(bp.freeBlocks) == 0 {
		return 0
	}
	totalFree := 0
	for _, free := range bp.freeBlocks {
		totalFree += len(free)
	}
	avgFree := float64(totalFree) / float64(len(bp.freeBlocks))
	return 1.0 - (avgFree / float64(bp.cfg.NumBlocks))
}
// Free releases all GPU memory held by the pool.
func (bp *BlockPool) Free() {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	// Free contiguous CUDA buffers first (if present).
	for i := range bp.kBase {
		if bp.kBase[i] != nil {
			bp.kBase[i].Free()
			bp.kBase[i] = nil
		}
	}
	for i := range bp.vBase {
		if bp.vBase[i] != nil {
			bp.vBase[i].Free()
			bp.vBase[i] = nil
		}
	}
	bp.kBase = nil
	bp.vBase = nil
	// Then free per-block CUDA tensors.
	for _, layerBlocks := range bp.blocks {
		for _, block := range layerBlocks {
			if block == nil {
				continue
			}
			if ct, ok := block.K.(*cuda.Tensor); ok && ct != nil {
				ct.Free()
			}
			if ct, ok := block.V.(*cuda.Tensor); ok && ct != nil {
				ct.Free()
			}
		}
	}
	bp.blocks = nil
	bp.freeBlocks = nil
	bp.hashToBlock = nil
}