manager.go

// Package kvcache provides KV cache management with prefix caching.
package kvcache

import (
	"fmt"
	"log"
	"sync"

	"makarna/pkg/tensor"
)

// RequestKVCache holds the KV cache state for a single request.
// This replaces the old session-based approach.
type RequestKVCache struct {
	RequestID string
	TokenIDs  []int // all tokens (prompt + generated)
	// BlockIDs per layer - which blocks this request owns
	BlockIDs [][]int
	// NumComputedTokens - how many tokens have KV computed
	NumComputedTokens int
	// BlockHashes for prefix matching
	BlockHashes []BlockHash
}

// KVCacheManager manages KV cache for all requests using a global block pool.
// This is the main interface for the inference engine.
type KVCacheManager struct {
	pool *BlockPool
	cfg  ManagerConfig

	mu       sync.Mutex
	requests map[string]*RequestKVCache

	// For contiguous GPU KV (optimization when all on one GPU)
	contiguousK []tensor.Tensor // [layer] -> [maxSeqLen, kvDim]
	contiguousV []tensor.Tensor
	useContig   bool
}

// ManagerConfig configures the KV cache manager.
type ManagerConfig struct {
	NumLayers      int
	NumKVHeads     int
	HeadDim        int
	BlockSize      int
	MaxSeqLen      int
	Device         tensor.DeviceType
	GPU            int
	EnableCaching  bool // enable prefix caching
	MaxNumRequests int  // max concurrent requests
}

// NewKVCacheManager creates a new KV cache manager.
func NewKVCacheManager(cfg ManagerConfig) (*KVCacheManager, error) {
	if cfg.BlockSize <= 0 {
		cfg.BlockSize = 16
	}
	if cfg.MaxSeqLen <= 0 {
		cfg.MaxSeqLen = 8192
	}
	if cfg.MaxNumRequests <= 0 {
		cfg.MaxNumRequests = 16
	}

	// Calculate number of blocks needed
	blocksPerRequest := (cfg.MaxSeqLen + cfg.BlockSize - 1) / cfg.BlockSize
	totalBlocks := blocksPerRequest * cfg.MaxNumRequests

	pool, err := NewBlockPool(BlockPoolConfig{
		NumLayers:  cfg.NumLayers,
		NumKVHeads: cfg.NumKVHeads,
		HeadDim:    cfg.HeadDim,
		BlockSize:  cfg.BlockSize,
		NumBlocks:  totalBlocks,
		Device:     cfg.Device,
		GPU:        cfg.GPU,
	})
	if err != nil {
		return nil, fmt.Errorf("create block pool: %w", err)
	}

	mgr := &KVCacheManager{
		pool:     pool,
		cfg:      cfg,
		requests: make(map[string]*RequestKVCache),
	}
	return mgr, nil
}
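
// Example construction (illustrative sketch only; the model dimensions below are
// made-up values, not taken from this repository, and Device/GPU are left at
// their zero values because the tensor package's device constants are not shown
// in this file):
//
//	mgr, err := NewKVCacheManager(ManagerConfig{
//		NumLayers:      32,
//		NumKVHeads:     8,
//		HeadDim:        128,
//		BlockSize:      16,
//		MaxSeqLen:      8192,
//		EnableCaching:  true,
//		MaxNumRequests: 16,
//		// Device and GPU: set from the tensor package's device values in real use.
//	})
//	if err != nil {
//		// handle error
//	}
//	defer mgr.Free()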
// AllocateRequest allocates KV cache for a new request.
// Returns the number of prefix-cached tokens (already computed).
func (m *KVCacheManager) AllocateRequest(requestID string, tokenIDs []int) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.requests[requestID]; exists {
		return 0, fmt.Errorf("request %s already exists", requestID)
	}

	// Compute block hashes for prefix matching
	hashes := ComputeBlockHashes(tokenIDs, m.cfg.BlockSize)

	// Find cached prefix
	numCachedTokens := 0
	cachedBlockIDs := make([][]int, m.cfg.NumLayers)
	if m.cfg.EnableCaching && len(hashes) > 0 {
		// Check layer 0 for prefix hits (all layers should have same structure)
		cached, tokens := m.pool.FindCachedBlocks(0, hashes)
		if len(cached) > 0 {
			numCachedTokens = tokens
			// Touch cached blocks in all layers
			for layer := 0; layer < m.cfg.NumLayers; layer++ {
				layerCached, _ := m.pool.FindCachedBlocks(layer, hashes[:len(cached)])
				m.pool.TouchBlocks(layer, layerCached)
				cachedBlockIDs[layer] = layerCached
			}
		}
	}

	// Allocate new blocks for uncached portion
	numBlocks := (len(tokenIDs) + m.cfg.BlockSize - 1) / m.cfg.BlockSize
	numCachedBlocks := len(cachedBlockIDs[0])
	numNewBlocks := numBlocks - numCachedBlocks

	allBlockIDs := make([][]int, m.cfg.NumLayers)
	for layer := 0; layer < m.cfg.NumLayers; layer++ {
		allBlockIDs[layer] = make([]int, 0, numBlocks)
		allBlockIDs[layer] = append(allBlockIDs[layer], cachedBlockIDs[layer]...)
		if numNewBlocks > 0 {
			newBlocks, err := m.pool.AllocateBlocks(layer, numNewBlocks)
			if err != nil {
				// Rollback: free already allocated
				for l := 0; l < layer; l++ {
					m.pool.FreeBlocks(l, allBlockIDs[l][numCachedBlocks:])
				}
				return 0, fmt.Errorf("layer %d: %w", layer, err)
			}
			allBlockIDs[layer] = append(allBlockIDs[layer], newBlocks...)
		}
	}

	// Don't treat the full prompt as cached; keep the last token to compute logits.
	if numCachedTokens >= len(tokenIDs) {
		numCachedTokens = len(tokenIDs) - 1
		if numCachedTokens < 0 {
			numCachedTokens = 0
		}
	}

	req := &RequestKVCache{
		RequestID:         requestID,
		TokenIDs:          tokenIDs,
		BlockIDs:          allBlockIDs,
		NumComputedTokens: numCachedTokens,
		BlockHashes:       hashes,
	}
	m.requests[requestID] = req

	if m.cfg.EnableCaching {
		log.Printf("kv_cache request=%s prompt_tokens=%d cached_tokens=%d new_tokens=%d blocks=%d",
			requestID, len(tokenIDs), numCachedTokens, len(tokenIDs)-numCachedTokens, numBlocks)
	}
	return numCachedTokens, nil
}
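
// Example prefill flow (illustrative sketch only; promptTokens and runPrefill are
// hypothetical engine-side names, not part of this package):
//
//	numCached, err := mgr.AllocateRequest("req-1", promptTokens)
//	if err != nil {
//		// handle error
//	}
//	// Only promptTokens[numCached:] need a forward pass; the first numCached
//	// tokens already have their KV in blocks reused from the prefix cache.
//	// runPrefill(promptTokens[numCached:]) // hypothetical engine call
//	mgr.CommitTokens("req-1", len(promptTokens)-numCached)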
// GetComputedBlocks returns KV block references for already-computed tokens.
func (m *KVCacheManager) GetComputedBlocks(requestID string, layer int) []*KVBlock {
	m.mu.Lock()
	defer m.mu.Unlock()

	req, ok := m.requests[requestID]
	if !ok || layer >= len(req.BlockIDs) {
		return nil
	}

	numComputedBlocks := req.NumComputedTokens / m.cfg.BlockSize
	blocks := make([]*KVBlock, 0, numComputedBlocks)
	for i := 0; i < numComputedBlocks && i < len(req.BlockIDs[layer]); i++ {
		block := m.pool.GetBlock(layer, req.BlockIDs[layer][i])
		if block != nil {
			blocks = append(blocks, block)
		}
	}
	return blocks
}

// GetBlockForWrite returns the block to write new KV data into.
func (m *KVCacheManager) GetBlockForWrite(requestID string, layer int, tokenPos int) *KVBlock {
	m.mu.Lock()
	defer m.mu.Unlock()

	req, ok := m.requests[requestID]
	if !ok || layer >= len(req.BlockIDs) {
		return nil
	}

	blockIdx := tokenPos / m.cfg.BlockSize
	if blockIdx >= len(req.BlockIDs[layer]) {
		return nil
	}
	return m.pool.GetBlock(layer, req.BlockIDs[layer][blockIdx])
}

// CommitTokens marks tokens as computed and optionally caches the blocks.
func (m *KVCacheManager) CommitTokens(requestID string, numTokens int) {
	m.mu.Lock()
	defer m.mu.Unlock()

	req, ok := m.requests[requestID]
	if !ok {
		return
	}
	oldComputed := req.NumComputedTokens
	req.NumComputedTokens += numTokens

	// Cache newly completed blocks
	if m.cfg.EnableCaching {
		oldBlocks := oldComputed / m.cfg.BlockSize
		newBlocks := req.NumComputedTokens / m.cfg.BlockSize
		if newBlocks > oldBlocks && newBlocks <= len(req.BlockHashes) {
			for layer := 0; layer < m.cfg.NumLayers; layer++ {
				blockIDs := req.BlockIDs[layer][oldBlocks:newBlocks]
				hashes := req.BlockHashes[oldBlocks:newBlocks]
				m.pool.CacheBlocks(layer, blockIDs, hashes)
			}
		}
	}
}

// AppendToken adds a generated token to the request.
func (m *KVCacheManager) AppendToken(requestID string, tokenID int) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	req, ok := m.requests[requestID]
	if !ok {
		return fmt.Errorf("request %s not found", requestID)
	}
	req.TokenIDs = append(req.TokenIDs, tokenID)

	// Check if we need more blocks
	numBlocks := (len(req.TokenIDs) + m.cfg.BlockSize - 1) / m.cfg.BlockSize
	currentBlocks := len(req.BlockIDs[0])
	if numBlocks > currentBlocks {
		// Allocate one more block per layer
		for layer := 0; layer < m.cfg.NumLayers; layer++ {
			newBlocks, err := m.pool.AllocateBlocks(layer, 1)
			if err != nil {
				return fmt.Errorf("layer %d: %w", layer, err)
			}
			req.BlockIDs[layer] = append(req.BlockIDs[layer], newBlocks...)
		}
		// Update block hashes
		req.BlockHashes = ComputeBlockHashes(req.TokenIDs, m.cfg.BlockSize)
	}
	return nil
}

// FreeRequest releases all blocks for a request.
func (m *KVCacheManager) FreeRequest(requestID string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	req, ok := m.requests[requestID]
	if !ok {
		return
	}
	for layer := 0; layer < m.cfg.NumLayers; layer++ {
		m.pool.FreeBlocks(layer, req.BlockIDs[layer])
	}
	delete(m.requests, requestID)
}
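
// Example decode step (illustrative sketch only; nextToken and the KV write are
// engine-side concerns and hypothetical here, and the position arithmetic is an
// assumption based on how NumComputedTokens is maintained above):
//
//	if err := mgr.AppendToken("req-1", nextToken); err != nil {
//		// out of blocks; handle error
//	}
//	pos := mgr.NumComputedTokens("req-1") // position of the token being decoded
//	for layer := 0; layer < mgr.Config().NumLayers; layer++ {
//		blk := mgr.GetBlockForWrite("req-1", layer, pos)
//		_ = blk // hypothetical: attention kernel writes K/V for slot pos%mgr.BlockSize()
//	}
//	mgr.CommitTokens("req-1", 1)
//	// ...repeat per generated token...
//	mgr.FreeRequest("req-1") // when generation finishes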
// GetRequest returns the request state.
func (m *KVCacheManager) GetRequest(requestID string) *RequestKVCache {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.requests[requestID]
}

// NumComputedTokens returns how many tokens are already computed for a request.
func (m *KVCacheManager) NumComputedTokens(requestID string) int {
	m.mu.Lock()
	defer m.mu.Unlock()
	if req, ok := m.requests[requestID]; ok {
		return req.NumComputedTokens
	}
	return 0
}

// Stats returns cache statistics.
func (m *KVCacheManager) Stats() PrefixCacheStats {
	return m.pool.Stats()
}

// Usage returns cache utilization.
func (m *KVCacheManager) Usage() float64 {
	return m.pool.Usage()
}

// Free releases all resources.
func (m *KVCacheManager) Free() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.pool.Free()
	m.requests = nil
}

// BlockSize returns configured block size.
func (m *KVCacheManager) BlockSize() int {
	return m.cfg.BlockSize
}

// Config returns the manager configuration.
func (m *KVCacheManager) Config() ManagerConfig {
	return m.cfg
}