| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- // Package kvcache provides KV cache management with prefix caching.
- package kvcache
- import (
- "fmt"
- "log"
- "sync"
- "makarna/pkg/tensor"
- )
// RequestKVCache holds the KV cache state for a single request.
// This replaces the old session-based approach.
type RequestKVCache struct {
	// RequestID uniquely identifies the owning request.
	RequestID string
	// TokenIDs holds all tokens (prompt + generated).
	TokenIDs []int
	// BlockIDs lists, per layer, which block IDs this request owns.
	// Indexed as BlockIDs[layer][blockIndex].
	BlockIDs [][]int
	// NumComputedTokens counts how many leading tokens already have
	// their KV entries computed.
	NumComputedTokens int
	// BlockHashes are per-block content hashes used for prefix matching.
	BlockHashes []BlockHash
}
// KVCacheManager manages KV cache for all requests using a global block pool.
// This is the main interface for the inference engine.
type KVCacheManager struct {
	pool *BlockPool
	cfg  ManagerConfig

	// mu guards requests. The pool is called without holding mu in
	// Stats/Usage — presumably it synchronizes internally; TODO confirm.
	mu       sync.Mutex
	requests map[string]*RequestKVCache

	// Contiguous GPU KV tensors (optimization when all on one GPU).
	// NOTE(review): these fields are never assigned in this file —
	// presumably populated elsewhere; verify before relying on useContig.
	contiguousK []tensor.Tensor // [layer] -> [maxSeqLen, kvDim]
	contiguousV []tensor.Tensor
	useContig   bool
}
// ManagerConfig configures the KV cache manager.
type ManagerConfig struct {
	NumLayers  int // number of transformer layers
	NumKVHeads int // KV heads per layer
	HeadDim    int // dimension per head
	BlockSize  int // tokens per KV block; defaults to 16 when <= 0
	MaxSeqLen  int // max sequence length per request; defaults to 8192 when <= 0

	Device tensor.DeviceType
	GPU    int

	EnableCaching  bool // enable prefix caching
	MaxNumRequests int  // max concurrent requests; defaults to 16 when <= 0
}
- // NewKVCacheManager creates a new KV cache manager.
- func NewKVCacheManager(cfg ManagerConfig) (*KVCacheManager, error) {
- if cfg.BlockSize <= 0 {
- cfg.BlockSize = 16
- }
- if cfg.MaxSeqLen <= 0 {
- cfg.MaxSeqLen = 8192
- }
- if cfg.MaxNumRequests <= 0 {
- cfg.MaxNumRequests = 16
- }
- // Calculate number of blocks needed
- blocksPerRequest := (cfg.MaxSeqLen + cfg.BlockSize - 1) / cfg.BlockSize
- totalBlocks := blocksPerRequest * cfg.MaxNumRequests
- pool, err := NewBlockPool(BlockPoolConfig{
- NumLayers: cfg.NumLayers,
- NumKVHeads: cfg.NumKVHeads,
- HeadDim: cfg.HeadDim,
- BlockSize: cfg.BlockSize,
- NumBlocks: totalBlocks,
- Device: cfg.Device,
- GPU: cfg.GPU,
- })
- if err != nil {
- return nil, fmt.Errorf("create block pool: %w", err)
- }
- mgr := &KVCacheManager{
- pool: pool,
- cfg: cfg,
- requests: make(map[string]*RequestKVCache),
- }
- return mgr, nil
- }
// AllocateRequest allocates KV cache for a new request.
// Returns the number of prefix-cached tokens (already computed).
//
// On a prefix-cache hit the matching leading blocks are reused (touched,
// not copied) and only the uncached tail receives fresh blocks. The
// returned count is clamped to len(tokenIDs)-1 so that at least the last
// prompt token is recomputed, which is required to produce logits.
func (m *KVCacheManager) AllocateRequest(requestID string, tokenIDs []int) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, exists := m.requests[requestID]; exists {
		return 0, fmt.Errorf("request %s already exists", requestID)
	}
	// Compute block hashes for prefix matching.
	hashes := ComputeBlockHashes(tokenIDs, m.cfg.BlockSize)
	// Find cached prefix.
	numCachedTokens := 0
	cachedBlockIDs := make([][]int, m.cfg.NumLayers)
	if m.cfg.EnableCaching && len(hashes) > 0 {
		// Check layer 0 for prefix hits (all layers should have same structure).
		cached, tokens := m.pool.FindCachedBlocks(0, hashes)
		if len(cached) > 0 {
			numCachedTokens = tokens
			// Touch cached blocks in all layers.
			// NOTE(review): assumes every layer caches the same prefix as
			// layer 0; if a layer returns fewer hits than layer 0, the
			// per-layer block lists below diverge in length — verify this
			// invariant against the pool implementation.
			for layer := 0; layer < m.cfg.NumLayers; layer++ {
				layerCached, _ := m.pool.FindCachedBlocks(layer, hashes[:len(cached)])
				m.pool.TouchBlocks(layer, layerCached)
				cachedBlockIDs[layer] = layerCached
			}
		}
	}
	// Allocate new blocks for uncached portion.
	numBlocks := (len(tokenIDs) + m.cfg.BlockSize - 1) / m.cfg.BlockSize
	numCachedBlocks := len(cachedBlockIDs[0])
	numNewBlocks := numBlocks - numCachedBlocks
	allBlockIDs := make([][]int, m.cfg.NumLayers)
	for layer := 0; layer < m.cfg.NumLayers; layer++ {
		allBlockIDs[layer] = make([]int, 0, numBlocks)
		allBlockIDs[layer] = append(allBlockIDs[layer], cachedBlockIDs[layer]...)
		if numNewBlocks > 0 {
			newBlocks, err := m.pool.AllocateBlocks(layer, numNewBlocks)
			if err != nil {
				// Rollback: free already allocated (only the newly
				// allocated tail; cached prefix blocks stay in the pool).
				// NOTE(review): blocks touched via TouchBlocks above are
				// not "un-touched" on this path — confirm TouchBlocks
				// does not hold a refcount that would leak here.
				for l := 0; l < layer; l++ {
					m.pool.FreeBlocks(l, allBlockIDs[l][numCachedBlocks:])
				}
				return 0, fmt.Errorf("layer %d: %w", layer, err)
			}
			allBlockIDs[layer] = append(allBlockIDs[layer], newBlocks...)
		}
	}
	// Don't treat the full prompt as cached; keep the last token to compute logits.
	if numCachedTokens >= len(tokenIDs) {
		numCachedTokens = len(tokenIDs) - 1
		if numCachedTokens < 0 {
			numCachedTokens = 0
		}
	}
	req := &RequestKVCache{
		RequestID:         requestID,
		TokenIDs:          tokenIDs,
		BlockIDs:          allBlockIDs,
		NumComputedTokens: numCachedTokens,
		BlockHashes:       hashes,
	}
	m.requests[requestID] = req
	if m.cfg.EnableCaching {
		log.Printf("kv_cache request=%s prompt_tokens=%d cached_tokens=%d new_tokens=%d blocks=%d",
			requestID, len(tokenIDs), numCachedTokens, len(tokenIDs)-numCachedTokens, numBlocks)
	}
	return numCachedTokens, nil
}
- // GetComputedBlocks returns KV block references for already-computed tokens.
- func (m *KVCacheManager) GetComputedBlocks(requestID string, layer int) []*KVBlock {
- m.mu.Lock()
- defer m.mu.Unlock()
- req, ok := m.requests[requestID]
- if !ok || layer >= len(req.BlockIDs) {
- return nil
- }
- numComputedBlocks := req.NumComputedTokens / m.cfg.BlockSize
- blocks := make([]*KVBlock, 0, numComputedBlocks)
- for i := 0; i < numComputedBlocks && i < len(req.BlockIDs[layer]); i++ {
- block := m.pool.GetBlock(layer, req.BlockIDs[layer][i])
- if block != nil {
- blocks = append(blocks, block)
- }
- }
- return blocks
- }
- // GetBlockForWrite returns the block to write new KV data into.
- func (m *KVCacheManager) GetBlockForWrite(requestID string, layer int, tokenPos int) *KVBlock {
- m.mu.Lock()
- defer m.mu.Unlock()
- req, ok := m.requests[requestID]
- if !ok || layer >= len(req.BlockIDs) {
- return nil
- }
- blockIdx := tokenPos / m.cfg.BlockSize
- if blockIdx >= len(req.BlockIDs[layer]) {
- return nil
- }
- return m.pool.GetBlock(layer, req.BlockIDs[layer][blockIdx])
- }
- // CommitTokens marks tokens as computed and optionally caches the blocks.
- func (m *KVCacheManager) CommitTokens(requestID string, numTokens int) {
- m.mu.Lock()
- defer m.mu.Unlock()
- req, ok := m.requests[requestID]
- if !ok {
- return
- }
- oldComputed := req.NumComputedTokens
- req.NumComputedTokens += numTokens
- // Cache newly completed blocks
- if m.cfg.EnableCaching {
- oldBlocks := oldComputed / m.cfg.BlockSize
- newBlocks := req.NumComputedTokens / m.cfg.BlockSize
- if newBlocks > oldBlocks && newBlocks <= len(req.BlockHashes) {
- for layer := 0; layer < m.cfg.NumLayers; layer++ {
- blockIDs := req.BlockIDs[layer][oldBlocks:newBlocks]
- hashes := req.BlockHashes[oldBlocks:newBlocks]
- m.pool.CacheBlocks(layer, blockIDs, hashes)
- }
- }
- }
- }
- // AppendToken adds a generated token to the request.
- func (m *KVCacheManager) AppendToken(requestID string, tokenID int) error {
- m.mu.Lock()
- defer m.mu.Unlock()
- req, ok := m.requests[requestID]
- if !ok {
- return fmt.Errorf("request %s not found", requestID)
- }
- req.TokenIDs = append(req.TokenIDs, tokenID)
- // Check if we need more blocks
- numBlocks := (len(req.TokenIDs) + m.cfg.BlockSize - 1) / m.cfg.BlockSize
- currentBlocks := len(req.BlockIDs[0])
- if numBlocks > currentBlocks {
- // Allocate one more block per layer
- for layer := 0; layer < m.cfg.NumLayers; layer++ {
- newBlocks, err := m.pool.AllocateBlocks(layer, 1)
- if err != nil {
- return fmt.Errorf("layer %d: %w", layer, err)
- }
- req.BlockIDs[layer] = append(req.BlockIDs[layer], newBlocks...)
- }
- // Update block hashes
- req.BlockHashes = ComputeBlockHashes(req.TokenIDs, m.cfg.BlockSize)
- }
- return nil
- }
- // FreeRequest releases all blocks for a request.
- func (m *KVCacheManager) FreeRequest(requestID string) {
- m.mu.Lock()
- defer m.mu.Unlock()
- req, ok := m.requests[requestID]
- if !ok {
- return
- }
- for layer := 0; layer < m.cfg.NumLayers; layer++ {
- m.pool.FreeBlocks(layer, req.BlockIDs[layer])
- }
- delete(m.requests, requestID)
- }
- // GetRequest returns the request state.
- func (m *KVCacheManager) GetRequest(requestID string) *RequestKVCache {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.requests[requestID]
- }
- // NumComputedTokens returns how many tokens are already computed for a request.
- func (m *KVCacheManager) NumComputedTokens(requestID string) int {
- m.mu.Lock()
- defer m.mu.Unlock()
- if req, ok := m.requests[requestID]; ok {
- return req.NumComputedTokens
- }
- return 0
- }
// Stats returns cache statistics.
// Delegates to the block pool without taking the manager lock — presumably
// the pool synchronizes its own stats; TODO confirm.
func (m *KVCacheManager) Stats() PrefixCacheStats {
	return m.pool.Stats()
}
// Usage returns cache utilization as reported by the block pool.
// Called without the manager lock — presumably the pool synchronizes
// internally; TODO confirm.
func (m *KVCacheManager) Usage() float64 {
	return m.pool.Usage()
}
- // Free releases all resources.
- func (m *KVCacheManager) Free() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.pool.Free()
- m.requests = nil
- }
// BlockSize returns the configured block size (tokens per KV block).
func (m *KVCacheManager) BlockSize() int {
	return m.cfg.BlockSize
}
// Config returns a copy of the manager configuration (including any
// defaults applied in NewKVCacheManager).
func (m *KVCacheManager) Config() ManagerConfig {
	return m.cfg
}
|