// Package kvcache provides KV cache management with prefix caching.
package kvcache

import (
	"fmt"
	"log"
	"sync"

	"makarna/pkg/tensor"
)

// RequestKVCache holds the KV cache state for a single request.
// This replaces the old session-based approach.
type RequestKVCache struct {
	// RequestID is the caller-supplied unique identifier for this request.
	RequestID string
	TokenIDs  []int // all tokens (prompt + generated)
	// BlockIDs per layer - which blocks this request owns.
	// Outer index is the layer, inner slice lists pool block IDs in token order.
	BlockIDs [][]int
	// NumComputedTokens - how many tokens have KV computed
	NumComputedTokens int
	// BlockHashes for prefix matching (one hash per completed block of TokenIDs).
	BlockHashes []BlockHash
}

// KVCacheManager manages KV cache for all requests using a global block pool.
// This is the main interface for the inference engine.
type KVCacheManager struct {
	pool *BlockPool    // shared block pool backing every request's KV storage
	cfg  ManagerConfig // immutable after construction

	mu       sync.Mutex // guards requests and all per-request state mutation
	requests map[string]*RequestKVCache

	// For contiguous GPU KV (optimization when all on one GPU)
	contiguousK []tensor.Tensor // [layer] -> [maxSeqLen, kvDim]
	contiguousV []tensor.Tensor
	useContig   bool // NOTE(review): never written in this file — presumably toggled elsewhere; verify
}

// ManagerConfig configures the KV cache manager.
type ManagerConfig struct {
	NumLayers  int // number of transformer layers (one block table per layer)
	NumKVHeads int // KV heads per layer
	HeadDim    int // dimension per head
	BlockSize  int // tokens per KV block (defaulted to 16 if <= 0)
	MaxSeqLen  int // maximum sequence length per request (defaulted to 8192 if <= 0)
	Device     tensor.DeviceType
	GPU        int

	EnableCaching  bool // enable prefix caching
	MaxNumRequests int  // max concurrent requests (defaulted to 16 if <= 0)
}

// NewKVCacheManager creates a new KV cache manager.
func NewKVCacheManager(cfg ManagerConfig) (*KVCacheManager, error) {
	// Fill in defaults for zero-valued config fields.
	if cfg.BlockSize <= 0 {
		cfg.BlockSize = 16
	}
	if cfg.MaxSeqLen <= 0 {
		cfg.MaxSeqLen = 8192
	}
	if cfg.MaxNumRequests <= 0 {
		cfg.MaxNumRequests = 16
	}
	// Calculate number of blocks needed: size the pool so every concurrent
	// request can hold a full MaxSeqLen worth of blocks (ceiling division).
	blocksPerRequest := (cfg.MaxSeqLen + cfg.BlockSize - 1) / cfg.BlockSize
	totalBlocks := blocksPerRequest * cfg.MaxNumRequests
	pool, err := NewBlockPool(BlockPoolConfig{
		NumLayers:  cfg.NumLayers,
		NumKVHeads: cfg.NumKVHeads,
		HeadDim:    cfg.HeadDim,
		BlockSize:  cfg.BlockSize,
		NumBlocks:  totalBlocks,
		Device:     cfg.Device,
		GPU:        cfg.GPU,
	})
	if err != nil {
		return nil, fmt.Errorf("create block pool: %w", err)
	}
	mgr := &KVCacheManager{
		pool:     pool,
		cfg:      cfg,
		requests: make(map[string]*RequestKVCache),
	}
	return mgr, nil
}

// AllocateRequest allocates KV cache for a new request.
// Returns the number of prefix-cached tokens (already computed).
func (m *KVCacheManager) AllocateRequest(requestID string, tokenIDs []int) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, exists := m.requests[requestID]; exists {
		return 0, fmt.Errorf("request %s already exists", requestID)
	}
	// Compute block hashes for prefix matching.
	// NOTE(review): presumably hashes cover only completed blocks, so
	// len(hashes) <= ceil(len(tokenIDs)/BlockSize) — verify ComputeBlockHashes.
	hashes := ComputeBlockHashes(tokenIDs, m.cfg.BlockSize)
	// Find cached prefix
	numCachedTokens := 0
	cachedBlockIDs := make([][]int, m.cfg.NumLayers)
	if m.cfg.EnableCaching && len(hashes) > 0 {
		// Check layer 0 for prefix hits (all layers should have same structure)
		cached, tokens := m.pool.FindCachedBlocks(0, hashes)
		if len(cached) > 0 {
			numCachedTokens = tokens
			// Touch cached blocks in all layers.
			// NOTE(review): assumes each layer's FindCachedBlocks returns the
			// same number of hits as layer 0; if a layer diverges, the
			// numCachedBlocks bookkeeping below is wrong — confirm pool invariant.
			for layer := 0; layer < m.cfg.NumLayers; layer++ {
				layerCached, _ := m.pool.FindCachedBlocks(layer, hashes[:len(cached)])
				m.pool.TouchBlocks(layer, layerCached)
				cachedBlockIDs[layer] = layerCached
			}
		}
	}
	// Allocate new blocks for uncached portion (ceiling division over all tokens).
	numBlocks := (len(tokenIDs) + m.cfg.BlockSize - 1) / m.cfg.BlockSize
	numCachedBlocks := len(cachedBlockIDs[0])
	numNewBlocks := numBlocks - numCachedBlocks
	allBlockIDs := make([][]int, m.cfg.NumLayers)
	for layer := 0; layer < m.cfg.NumLayers; layer++ {
		allBlockIDs[layer] = make([]int, 0, numBlocks)
		allBlockIDs[layer] = append(allBlockIDs[layer], cachedBlockIDs[layer]...)
		if numNewBlocks > 0 {
			newBlocks, err := m.pool.AllocateBlocks(layer, numNewBlocks)
			if err != nil {
				// Rollback: free already allocated (only the freshly allocated
				// tail of each earlier layer, not the shared cached prefix).
				// NOTE(review): blocks touched via TouchBlocks above are never
				// released on this path — if TouchBlocks increments a refcount,
				// this leaks a reference per cached block; verify pool semantics.
				for l := 0; l < layer; l++ {
					m.pool.FreeBlocks(l, allBlockIDs[l][numCachedBlocks:])
				}
				return 0, fmt.Errorf("layer %d: %w", layer, err)
			}
			allBlockIDs[layer] = append(allBlockIDs[layer], newBlocks...)
		}
	}
	// Don't treat the full prompt as cached; keep the last token to compute logits.
	if numCachedTokens >= len(tokenIDs) {
		numCachedTokens = len(tokenIDs) - 1
		if numCachedTokens < 0 {
			numCachedTokens = 0
		}
	}
	req := &RequestKVCache{
		RequestID:         requestID,
		TokenIDs:          tokenIDs,
		BlockIDs:          allBlockIDs,
		NumComputedTokens: numCachedTokens,
		BlockHashes:       hashes,
	}
	m.requests[requestID] = req
	if m.cfg.EnableCaching {
		log.Printf("kv_cache request=%s prompt_tokens=%d cached_tokens=%d new_tokens=%d blocks=%d",
			requestID, len(tokenIDs), numCachedTokens, len(tokenIDs)-numCachedTokens, numBlocks)
	}
	return numCachedTokens, nil
}

// GetComputedBlocks returns KV block references for already-computed tokens.
func (m *KVCacheManager) GetComputedBlocks(requestID string, layer int) []*KVBlock {
	m.mu.Lock()
	defer m.mu.Unlock()
	req, ok := m.requests[requestID]
	if !ok || layer >= len(req.BlockIDs) {
		return nil
	}
	// Only fully computed blocks are returned (floor division).
	numComputedBlocks := req.NumComputedTokens / m.cfg.BlockSize
	blocks := make([]*KVBlock, 0, numComputedBlocks)
	for i := 0; i < numComputedBlocks && i < len(req.BlockIDs[layer]); i++ {
		block := m.pool.GetBlock(layer, req.BlockIDs[layer][i])
		if block != nil {
			blocks = append(blocks, block)
		}
	}
	return blocks
}

// GetBlockForWrite returns the block to write new KV data into.
func (m *KVCacheManager) GetBlockForWrite(requestID string, layer int, tokenPos int) *KVBlock { m.mu.Lock() defer m.mu.Unlock() req, ok := m.requests[requestID] if !ok || layer >= len(req.BlockIDs) { return nil } blockIdx := tokenPos / m.cfg.BlockSize if blockIdx >= len(req.BlockIDs[layer]) { return nil } return m.pool.GetBlock(layer, req.BlockIDs[layer][blockIdx]) } // CommitTokens marks tokens as computed and optionally caches the blocks. func (m *KVCacheManager) CommitTokens(requestID string, numTokens int) { m.mu.Lock() defer m.mu.Unlock() req, ok := m.requests[requestID] if !ok { return } oldComputed := req.NumComputedTokens req.NumComputedTokens += numTokens // Cache newly completed blocks if m.cfg.EnableCaching { oldBlocks := oldComputed / m.cfg.BlockSize newBlocks := req.NumComputedTokens / m.cfg.BlockSize if newBlocks > oldBlocks && newBlocks <= len(req.BlockHashes) { for layer := 0; layer < m.cfg.NumLayers; layer++ { blockIDs := req.BlockIDs[layer][oldBlocks:newBlocks] hashes := req.BlockHashes[oldBlocks:newBlocks] m.pool.CacheBlocks(layer, blockIDs, hashes) } } } } // AppendToken adds a generated token to the request. func (m *KVCacheManager) AppendToken(requestID string, tokenID int) error { m.mu.Lock() defer m.mu.Unlock() req, ok := m.requests[requestID] if !ok { return fmt.Errorf("request %s not found", requestID) } req.TokenIDs = append(req.TokenIDs, tokenID) // Check if we need more blocks numBlocks := (len(req.TokenIDs) + m.cfg.BlockSize - 1) / m.cfg.BlockSize currentBlocks := len(req.BlockIDs[0]) if numBlocks > currentBlocks { // Allocate one more block per layer for layer := 0; layer < m.cfg.NumLayers; layer++ { newBlocks, err := m.pool.AllocateBlocks(layer, 1) if err != nil { return fmt.Errorf("layer %d: %w", layer, err) } req.BlockIDs[layer] = append(req.BlockIDs[layer], newBlocks...) } // Update block hashes req.BlockHashes = ComputeBlockHashes(req.TokenIDs, m.cfg.BlockSize) } return nil } // FreeRequest releases all blocks for a request. 
func (m *KVCacheManager) FreeRequest(requestID string) { m.mu.Lock() defer m.mu.Unlock() req, ok := m.requests[requestID] if !ok { return } for layer := 0; layer < m.cfg.NumLayers; layer++ { m.pool.FreeBlocks(layer, req.BlockIDs[layer]) } delete(m.requests, requestID) } // GetRequest returns the request state. func (m *KVCacheManager) GetRequest(requestID string) *RequestKVCache { m.mu.Lock() defer m.mu.Unlock() return m.requests[requestID] } // NumComputedTokens returns how many tokens are already computed for a request. func (m *KVCacheManager) NumComputedTokens(requestID string) int { m.mu.Lock() defer m.mu.Unlock() if req, ok := m.requests[requestID]; ok { return req.NumComputedTokens } return 0 } // Stats returns cache statistics. func (m *KVCacheManager) Stats() PrefixCacheStats { return m.pool.Stats() } // Usage returns cache utilization. func (m *KVCacheManager) Usage() float64 { return m.pool.Usage() } // Free releases all resources. func (m *KVCacheManager) Free() { m.mu.Lock() defer m.mu.Unlock() m.pool.Free() m.requests = nil } // BlockSize returns configured block size. func (m *KVCacheManager) BlockSize() int { return m.cfg.BlockSize } // Config returns the manager configuration. func (m *KVCacheManager) Config() ManagerConfig { return m.cfg }