- // Package kvcache implements a paged KV cache with global prefix caching.
- package kvcache
- import (
- "fmt"
- "math"
- "sync"
- "unsafe"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cuda"
- "makarna/pkg/tensor"
- )
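- // float32ToFloat16Bits converts a float32 to IEEE 754 binary16 bits
- // (1 sign bit, 5 exponent bits, 10 mantissa bits, exponent bias 15) using
- // round-to-nearest-even, handling NaN/Inf, subnormals, underflow, and overflow.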
- func float32ToFloat16Bits(f float32) uint16 {
- bits := math.Float32bits(f)
- sign := uint16((bits >> 16) & 0x8000)
- exp := int((bits >> 23) & 0xFF)
- mant := bits & 0x007FFFFF
- // NaN / Inf
- if exp == 255 {
- if mant == 0 {
- return sign | 0x7C00
- }
- m := uint16(mant >> 13)
- if m == 0 {
- m = 1
- }
- return sign | 0x7C00 | m
- }
- // Bias adjust: float32 bias=127, float16 bias=15
- exp16 := exp - 127 + 15
- // Overflow -> Inf
- if exp16 >= 31 {
- return sign | 0x7C00
- }
- // Subnormal / underflow
- if exp16 <= 0 {
- if exp16 < -10 {
- return sign
- }
- // Add implicit leading 1.
- mant |= 0x00800000
- // Shift to 10-bit mantissa (plus 13-bit alignment), with round-to-nearest-even.
- shift := uint32(1-exp16) + 13
- m16 := mant >> shift
- rem := mant & ((uint32(1) << shift) - 1)
- half := uint32(1) << (shift - 1)
- if rem > half || (rem == half && (m16&1) == 1) {
- m16++
- }
- return sign | uint16(m16)
- }
- // Normalized: round mantissa to 10 bits (round-to-nearest-even).
- m16 := mant >> 13
- rem := mant & 0x1FFF
- if rem > 0x1000 || (rem == 0x1000 && (m16&1) == 1) {
- m16++
- if m16 == 0x400 {
- m16 = 0
- exp16++
- if exp16 >= 31 {
- return sign | 0x7C00
- }
- }
- }
- return sign | uint16(exp16<<10) | uint16(m16)
- }
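- // float32ToBFloat16Bits converts a float32 to bfloat16 bits. bfloat16 keeps the
- // float32 sign and 8-bit exponent and shortens the mantissa to 7 bits, so the
- // conversion rounds away the low 16 bits with ties to even.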
- func float32ToBFloat16Bits(f float32) uint16 {
- bits := math.Float32bits(f)
- // NaN: return a quiet NaN directly. Rounding the payload below could otherwise
- // carry out of the exponent field (e.g. 0x7FFFFFFF would round to 0x8000, i.e. -0).
- if bits&0x7FFFFFFF > 0x7F800000 {
- return uint16(bits>>16) | 0x0040
- }
- upper := uint16(bits >> 16)
- lower := uint16(bits & 0xFFFF)
- // Round the discarded low 16 bits to nearest, ties to even.
- if lower > 0x8000 || (lower == 0x8000 && (upper&1) == 1) {
- upper++
- }
- return upper
- }
- // PagedKVCache is a paged KV cache that uses a global BlockPool.
- // Unlike the old Cache which allocated per-request, this shares blocks across
- // requests and enables prefix caching via block hashing.
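- //
- // Each block stores BlockSize tokens of K and V per layer, laid out token-major as
- // [BlockSize][numKVHeads*headDim] (see Append/AppendKV); the optional packed CPU
- // buffers (block.pk/pv) use the head-major layout written by writePackedBlockU32
- // and read via ViewsPacked.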
- type PagedKVCache struct {
- pool *BlockPool
- cfg PagedCacheConfig
- mu sync.RWMutex
- // Per-request state
- requestID string
- tokenIDs []int
- // Block allocation per layer: blockIDs[layer] = list of block IDs
- blockIDs [][]int
- // Number of tokens whose KV has been computed and committed.
- numComputed int
- // Number of tokens whose KV has been written (may run ahead of numComputed within a step).
- numWritten int
- // Block hashes for prefix matching
- blockHashes []BlockHash
- // Whether this cache owns its blocks (vs borrowed from prefix cache)
- ownedBlocks [][]bool
- // Cached device pointer tables for paged attention (per layer).
- // These are rebuilt only when the requested block count or device changes.
- ptrTables []devicePtrTable
- }
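- // devicePtrTable caches a device-resident array of per-block K/V base pointers for
- // one layer, so the paged attention kernels can index blocks without re-uploading
- // the table on every step.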
- type devicePtrTable struct {
- kDev unsafe.Pointer
- vDev unsafe.Pointer
- len int
- gpu int
- kvType tensor.DType
- }
- func (c *PagedKVCache) clearPtrTablesLocked() {
- if c.ptrTables == nil {
- return
- }
- for i := range c.ptrTables {
- if c.ptrTables[i].kDev != nil {
- cuda.FreeDevicePtr(c.ptrTables[i].kDev)
- }
- if c.ptrTables[i].vDev != nil {
- cuda.FreeDevicePtr(c.ptrTables[i].vDev)
- }
- c.ptrTables[i] = devicePtrTable{}
- }
- }
- func (c *PagedKVCache) ensureCapacityLocked(requiredBlocks int) error {
- if requiredBlocks <= 0 {
- return nil
- }
- if len(c.blockIDs) == 0 {
- return fmt.Errorf("cache not initialized")
- }
- // All layers are expected to have the same number of blocks.
- cur := len(c.blockIDs[0])
- if requiredBlocks <= cur {
- return nil
- }
- need := requiredBlocks - cur
- allocatedByLayer := make([][]int, c.cfg.NumLayers)
- for layer := 0; layer < c.cfg.NumLayers; layer++ {
- newBlocks, err := c.pool.AllocateBlocks(layer, need)
- if err != nil {
- // rollback
- for l := 0; l < layer; l++ {
- c.pool.FreeBlocks(l, allocatedByLayer[l])
- }
- return err
- }
- allocatedByLayer[layer] = newBlocks
- c.blockIDs[layer] = append(c.blockIDs[layer], newBlocks...)
- for range newBlocks {
- c.ownedBlocks[layer] = append(c.ownedBlocks[layer], true)
- }
- }
- return nil
- }
- // AppendToken updates the token history for this request.
- // This is used to extend block hashing/caching beyond the initial prompt.
- func (c *PagedKVCache) AppendToken(tokenID int) {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.tokenIDs = append(c.tokenIDs, tokenID)
- // Only add hashes for fully completed blocks.
- bs := c.pool.cfg.BlockSize
- if bs <= 0 {
- return
- }
- if len(c.tokenIDs)%bs != 0 {
- return
- }
- blockIdx := len(c.tokenIDs)/bs - 1
- if blockIdx < 0 {
- return
- }
- // Ensure blockIDs can hold this block. An allocation failure is tolerated here;
- // Append re-ensures capacity before any KV is written.
- _ = c.ensureCapacityLocked(blockIdx + 1)
- start := blockIdx * bs
- end := start + bs
- if end > len(c.tokenIDs) {
- end = len(c.tokenIDs)
- }
- var parent *BlockHash
- if blockIdx-1 >= 0 && blockIdx-1 < len(c.blockHashes) {
- parent = &c.blockHashes[blockIdx-1]
- }
- h := ComputeBlockHash(c.tokenIDs[start:end], parent)
- // Extend or overwrite
- if blockIdx < len(c.blockHashes) {
- c.blockHashes[blockIdx] = h
- } else {
- c.blockHashes = append(c.blockHashes, h)
- }
- }
- // PagedCacheConfig configures a paged KV cache.
- type PagedCacheConfig struct {
- NumLayers int
- NumKVHeads int
- HeadDim int
- BlockSize int
- MaxSeqLen int
- Device tensor.DeviceType
- GPU int
- }
- // NewPagedKVCache creates a new paged cache backed by the given block pool.
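- //
- // A typical request lifecycle looks roughly like this (a sketch; the BlockPool
- // construction and the per-layer model wiring live outside this type and are
- // assumed here):
- //
- //	cache := NewPagedKVCache(pool, cfg, "req-1")
- //	cached, err := cache.AllocateForTokens(promptTokens) // prefix-cache hit length
- //	// ... run prefill from position `cached`, calling Append(layer, k, v) per layer ...
- //	cache.Commit(len(promptTokens) - cached)
- //	// decode loop: AppendToken(next), then Append per layer, then Commit(1)
- //	defer cache.Free()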
- func NewPagedKVCache(pool *BlockPool, cfg PagedCacheConfig, requestID string) *PagedKVCache {
- return &PagedKVCache{
- pool: pool,
- cfg: cfg,
- requestID: requestID,
- blockIDs: make([][]int, cfg.NumLayers),
- ownedBlocks: make([][]bool, cfg.NumLayers),
- }
- }
- // AllocateForTokens allocates blocks for the given token sequence and returns the
- // number of tokens whose KV is already cached (prefix hit), capped at len(tokens)-1
- // so the final position is always recomputed to produce logits.
- func (c *PagedKVCache) AllocateForTokens(tokens []int) (int, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.cfg.MaxSeqLen > 0 && len(tokens) > c.cfg.MaxSeqLen {
- return 0, fmt.Errorf("prompt length %d exceeds MaxSeqLen %d", len(tokens), c.cfg.MaxSeqLen)
- }
- c.tokenIDs = tokens
- c.blockHashes = ComputeBlockHashes(tokens, c.pool.cfg.BlockSize)
- c.clearPtrTablesLocked()
- numBlocks := len(c.blockHashes)
- if numBlocks == 0 {
- return 0, nil
- }
- // Find cached prefix in layer 0 (all layers have same structure)
- cachedBlockIDs, cachedTokens := c.pool.FindCachedBlocks(0, c.blockHashes)
- numCachedBlocks := len(cachedBlockIDs)
- // Allocate blocks for all layers
- for layer := 0; layer < c.cfg.NumLayers; layer++ {
- c.blockIDs[layer] = make([]int, 0, numBlocks)
- c.ownedBlocks[layer] = make([]bool, 0, numBlocks)
- // First, get cached blocks for this layer
- if numCachedBlocks > 0 {
- layerCached, _ := c.pool.FindCachedBlocks(layer, c.blockHashes[:numCachedBlocks])
- c.pool.TouchBlocks(layer, layerCached)
- c.blockIDs[layer] = append(c.blockIDs[layer], layerCached...)
- for range layerCached {
- c.ownedBlocks[layer] = append(c.ownedBlocks[layer], false) // borrowed
- }
- }
- // Allocate new blocks for uncached portion
- numNewBlocks := numBlocks - numCachedBlocks
- if numNewBlocks > 0 {
- newBlocks, err := c.pool.AllocateBlocks(layer, numNewBlocks)
- if err != nil {
- // Rollback on failure
- for l := 0; l < layer; l++ {
- c.pool.FreeBlocks(l, c.blockIDs[l])
- }
- return 0, fmt.Errorf("layer %d allocation failed: %w", layer, err)
- }
- c.blockIDs[layer] = append(c.blockIDs[layer], newBlocks...)
- for range newBlocks {
- c.ownedBlocks[layer] = append(c.ownedBlocks[layer], true) // owned
- }
- }
- }
- // Never count the very last token as computed, so the final position is always
- // recomputed and yields logits.
- c.numComputed = cachedTokens
- c.numWritten = c.numComputed
- if c.numComputed >= len(tokens) {
- c.numComputed = len(tokens) - 1
- if c.numComputed < 0 {
- c.numComputed = 0
- }
- c.numWritten = c.numComputed
- }
- return c.numComputed, nil
- }
- // LayerDevicePtrTables returns device pointers to contiguous pointer tables for this layer.
- // The returned pointers remain valid until the cache is freed or the tables are rebuilt,
- // which happens when the requested block count or the layer's GPU changes.
- func (c *PagedKVCache) LayerDevicePtrTables(layer int, numBlocks int) (unsafe.Pointer, unsafe.Pointer, int, tensor.DType, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.pool == nil {
- return nil, nil, 0, 0, fmt.Errorf("cache not initialized")
- }
- if layer < 0 || layer >= len(c.blockIDs) {
- return nil, nil, 0, 0, fmt.Errorf("invalid layer %d", layer)
- }
- if numBlocks <= 0 {
- return nil, nil, 0, 0, fmt.Errorf("numBlocks must be > 0")
- }
- layerDev := c.pool.LayerDevice(layer).Normalize()
- if layerDev.Type != tensor.CUDA {
- return nil, nil, 0, 0, fmt.Errorf("layer %d is not on CUDA (got %v)", layer, layerDev)
- }
- if numBlocks > len(c.blockIDs[layer]) {
- numBlocks = len(c.blockIDs[layer])
- }
- if numBlocks <= 0 {
- return nil, nil, 0, 0, fmt.Errorf("no blocks")
- }
- if c.ptrTables == nil || len(c.ptrTables) != c.cfg.NumLayers {
- c.ptrTables = make([]devicePtrTable, c.cfg.NumLayers)
- }
- pt := &c.ptrTables[layer]
- if pt.kDev != nil && pt.vDev != nil && pt.len == numBlocks && pt.gpu == layerDev.GPU {
- return pt.kDev, pt.vDev, c.pool.cfg.BlockSize, pt.kvType, nil
- }
- // Rebuild pointer tables on growth or device mismatch.
- kPtrs := make([]uintptr, numBlocks)
- vPtrs := make([]uintptr, numBlocks)
- var kvType tensor.DType
- for i := 0; i < numBlocks; i++ {
- bid := c.blockIDs[layer][i]
- b := c.pool.GetBlock(layer, bid)
- if b == nil {
- return nil, nil, 0, 0, fmt.Errorf("block %d not found", bid)
- }
- kT, ok := b.K.(*cuda.Tensor)
- if !ok {
- return nil, nil, 0, 0, fmt.Errorf("block K is not CUDA tensor")
- }
- vT, ok := b.V.(*cuda.Tensor)
- if !ok {
- return nil, nil, 0, 0, fmt.Errorf("block V is not CUDA tensor")
- }
- if kvType == 0 {
- kvType = kT.DType()
- }
- if kT.DType() != kvType || vT.DType() != kvType {
- return nil, nil, 0, 0, fmt.Errorf("mixed KV dtypes in blocks")
- }
- kPtrs[i] = uintptr(kT.Data().(unsafe.Pointer))
- vPtrs[i] = uintptr(vT.Data().(unsafe.Pointer))
- }
- kDev, err := cuda.AllocAndCopyPtrTable(kPtrs, layerDev.GPU)
- if err != nil {
- return nil, nil, 0, 0, err
- }
- vDev, err := cuda.AllocAndCopyPtrTable(vPtrs, layerDev.GPU)
- if err != nil {
- cuda.FreeDevicePtr(kDev)
- return nil, nil, 0, 0, err
- }
- if pt.kDev != nil {
- cuda.FreeDevicePtr(pt.kDev)
- }
- if pt.vDev != nil {
- cuda.FreeDevicePtr(pt.vDev)
- }
- pt.kDev = kDev
- pt.vDev = vDev
- pt.len = numBlocks
- pt.gpu = layerDev.GPU
- pt.kvType = kvType
- return pt.kDev, pt.vDev, c.pool.cfg.BlockSize, pt.kvType, nil
- }
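- // writePackedBlockU32 scatters float32 K/V rows for writeCount tokens into the
- // head-major packed CPU buffers that ViewsPacked exposes.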
- func writePackedBlockU32(dstK, dstV []float32, numKVHeads, headDim, blockSize, blockOffset, writeCount int, kSrc, vSrc []float32) {
- // kSrc/vSrc are token-major: [t][kvHead][d]
- // dstK/dstV are head-major: [kvHead][t][d]
- for t := 0; t < writeCount; t++ {
- baseTok := t * (numKVHeads * headDim)
- dstTok := blockOffset + t
- for h := 0; h < numKVHeads; h++ {
- srcBase := baseTok + h*headDim
- dstBase := h*(blockSize*headDim) + dstTok*headDim
- copy(dstK[dstBase:dstBase+headDim], kSrc[srcBase:srcBase+headDim])
- copy(dstV[dstBase:dstBase+headDim], vSrc[srcBase:srcBase+headDim])
- }
- }
- }
- // GetBlockForPosition returns the KV block for writing at a token position.
- func (c *PagedKVCache) GetBlockForPosition(layer, tokenPos int) *KVBlock {
- c.mu.RLock()
- defer c.mu.RUnlock()
- if layer < 0 || layer >= len(c.blockIDs) {
- return nil
- }
- blockIdx := tokenPos / c.pool.cfg.BlockSize
- if blockIdx < 0 || blockIdx >= len(c.blockIDs[layer]) {
- return nil
- }
- blockID := c.blockIDs[layer][blockIdx]
- return c.pool.GetBlock(layer, blockID)
- }
- // GetBlockOffset returns the offset within a block for a token position.
- func (c *PagedKVCache) GetBlockOffset(tokenPos int) int {
- return tokenPos % c.pool.cfg.BlockSize
- }
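- // LayerBlockIDs returns a copy of up to numBlocks block IDs allocated for a layer,
- // in sequence order.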
- func (c *PagedKVCache) LayerBlockIDs(layer int, numBlocks int) []int {
- c.mu.RLock()
- defer c.mu.RUnlock()
- if layer < 0 || layer >= len(c.blockIDs) {
- return nil
- }
- if numBlocks <= 0 {
- return nil
- }
- if numBlocks > len(c.blockIDs[layer]) {
- numBlocks = len(c.blockIDs[layer])
- }
- out := make([]int, numBlocks)
- copy(out, c.blockIDs[layer][:numBlocks])
- return out
- }
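- // LayerBlockPtrTables returns host-side tables of device pointers to the K and V
- // buffers of the first numBlocks CUDA blocks of a layer, along with the block size
- // and KV dtype. Unlike LayerDevicePtrTables, the tables are neither cached nor
- // uploaded to the device.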
- func (c *PagedKVCache) LayerBlockPtrTables(layer int, numBlocks int) ([]uintptr, []uintptr, int, tensor.DType, error) {
- if numBlocks <= 0 {
- return nil, nil, 0, 0, fmt.Errorf("numBlocks must be > 0")
- }
- blockIDs := c.LayerBlockIDs(layer, numBlocks)
- if len(blockIDs) == 0 {
- return nil, nil, 0, 0, fmt.Errorf("no blocks")
- }
- kPtrs := make([]uintptr, len(blockIDs))
- vPtrs := make([]uintptr, len(blockIDs))
- var kvType tensor.DType
- for i, bid := range blockIDs {
- b := c.pool.GetBlock(layer, bid)
- if b == nil {
- return nil, nil, 0, 0, fmt.Errorf("block %d not found", bid)
- }
- kT, ok := b.K.(*cuda.Tensor)
- if !ok {
- return nil, nil, 0, 0, fmt.Errorf("block K is not CUDA tensor")
- }
- vT, ok := b.V.(*cuda.Tensor)
- if !ok {
- return nil, nil, 0, 0, fmt.Errorf("block V is not CUDA tensor")
- }
- if kvType == 0 {
- kvType = kT.DType()
- }
- if kT.DType() != kvType || vT.DType() != kvType {
- return nil, nil, 0, 0, fmt.Errorf("mixed KV dtypes in blocks")
- }
- kPtrs[i] = uintptr(kT.Data().(unsafe.Pointer))
- vPtrs[i] = uintptr(vT.Data().(unsafe.Pointer))
- }
- return kPtrs, vPtrs, c.pool.cfg.BlockSize, kvType, nil
- }
- // Append writes new K/V tokens into the cache for a layer.
- // Implements KVCacheInterface.
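- // It returns block views covering positions [0, startPos+newTokens) plus the start
- // position at which the new tokens were written; the caller advances the committed
- // length separately via Commit.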
- func (c *PagedKVCache) Append(layer int, k, v tensor.Tensor) ([]View, int, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if layer < 0 || layer >= len(c.blockIDs) {
- return nil, 0, fmt.Errorf("invalid layer %d", layer)
- }
- layerPlacement := c.LayerDevice(layer).Normalize()
- startPos := c.numComputed
- newTokens := k.Shape()[0]
- kvDim := k.Shape()[1]
- endPos := startPos + newTokens
- if c.cfg.MaxSeqLen > 0 && endPos > c.cfg.MaxSeqLen {
- return nil, 0, fmt.Errorf("KV cache overflow: need %d tokens (start %d + new %d) > MaxSeqLen %d", endPos, startPos, newTokens, c.cfg.MaxSeqLen)
- }
- if layerPlacement.Type == tensor.CUDA {
- kSrc, kIsCUDA := k.(*cuda.Tensor)
- vSrc, vIsCUDA := v.(*cuda.Tensor)
- if !kIsCUDA || !vIsCUDA {
- return nil, 0, fmt.Errorf("PagedKVCache layer %d requires CUDA tensors, got %T/%T", layer, k, v)
- }
- if layerPlacement.GPU >= 0 {
- if kSrc.GPU() != layerPlacement.GPU || vSrc.GPU() != layerPlacement.GPU {
- return nil, 0, fmt.Errorf("PagedKVCache layer %d on GPU %d, got k/v on GPU %d/%d", layer, layerPlacement.GPU, kSrc.GPU(), vSrc.GPU())
- }
- }
- // Ensure we have enough blocks to cover the write range.
- requiredBlocks := (endPos + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
- if err := c.ensureCapacityLocked(requiredBlocks); err != nil {
- return nil, 0, fmt.Errorf("ensure capacity: %w", err)
- }
- written := 0
- for written < newTokens {
- globalPos := startPos + written
- blockIdx := globalPos / c.pool.cfg.BlockSize
- blockOffset := globalPos % c.pool.cfg.BlockSize
- // Capacity is ensured above.
- block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
- if block == nil {
- return nil, 0, fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
- }
- // How many tokens can we write to this block?
- writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
- // Copy K/V to block
- srcStart := written * kvDim
- dstStart := blockOffset * kvDim
- length := writeCount * kvDim
- kDst, ok := block.K.(*cuda.Tensor)
- if !ok {
- return nil, 0, fmt.Errorf("block K is not CUDA tensor")
- }
- vDst, ok := block.V.(*cuda.Tensor)
- if !ok {
- return nil, 0, fmt.Errorf("block V is not CUDA tensor")
- }
- if kDst.DType() == tensor.Float16 && kSrc.DType() == tensor.Float32 {
- srcPtr := unsafe.Pointer(uintptr(kSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
- dstPtr := unsafe.Pointer(uintptr(kDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
- if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, kDst.GPU()); err != nil {
- return nil, 0, fmt.Errorf("cast K f32->f16 failed: %w", err)
- }
- } else {
- if err := kDst.CopyPartialFromDevice(dstStart, kSrc, srcStart, length); err != nil {
- return nil, 0, fmt.Errorf("copy K failed: %w", err)
- }
- }
- if vDst.DType() == tensor.Float16 && vSrc.DType() == tensor.Float32 {
- srcPtr := unsafe.Pointer(uintptr(vSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
- dstPtr := unsafe.Pointer(uintptr(vDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
- if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, vDst.GPU()); err != nil {
- return nil, 0, fmt.Errorf("cast V f32->f16 failed: %w", err)
- }
- } else {
- if err := vDst.CopyPartialFromDevice(dstStart, vSrc, srcStart, length); err != nil {
- return nil, 0, fmt.Errorf("copy V failed: %w", err)
- }
- }
- written += writeCount
- }
- if endPos > c.numWritten {
- c.numWritten = endPos
- }
- return c.viewsLockedAt(layer, endPos), startPos, nil
- }
- kSrc, kIsCPU := k.(*cpu.Tensor)
- vSrc, vIsCPU := v.(*cpu.Tensor)
- if !kIsCPU || !vIsCPU {
- return nil, 0, fmt.Errorf("PagedKVCache layer %d requires CPU tensors, got %T/%T", layer, k, v)
- }
- // Ensure we have enough blocks to cover the write range.
- requiredBlocks := (endPos + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
- if err := c.ensureCapacityLocked(requiredBlocks); err != nil {
- return nil, 0, fmt.Errorf("ensure capacity: %w", err)
- }
- var kF32, vF32 []float32
- if kSrc.DType() == tensor.Float32 {
- kF32 = kSrc.DataFloat32()
- }
- if vSrc.DType() == tensor.Float32 {
- vF32 = vSrc.DataFloat32()
- }
- written := 0
- for written < newTokens {
- globalPos := startPos + written
- blockIdx := globalPos / c.pool.cfg.BlockSize
- blockOffset := globalPos % c.pool.cfg.BlockSize
- // Capacity is ensured above.
- block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
- if block == nil {
- return nil, 0, fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
- }
- // How many tokens can we write to this block?
- writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
- // Copy K/V to block
- srcStart := written * kvDim
- dstStart := blockOffset * kvDim
- length := writeCount * kvDim
- kDst, ok := block.K.(*cpu.Tensor)
- if !ok {
- return nil, 0, fmt.Errorf("block K is not CPU tensor")
- }
- vDst, ok := block.V.(*cpu.Tensor)
- if !ok {
- return nil, 0, fmt.Errorf("block V is not CPU tensor")
- }
- switch kDst.DType() {
- case tensor.Float16:
- kOut := kDst.DataUint16()[dstStart : dstStart+length]
- switch kSrc.DType() {
- case tensor.Float32:
- kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- kOut[i] = float32ToFloat16Bits(kIn[i])
- }
- case tensor.Float16:
- copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return nil, 0, fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
- }
- case tensor.BFloat16:
- kOut := kDst.DataUint16()[dstStart : dstStart+length]
- switch kSrc.DType() {
- case tensor.Float32:
- kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- kOut[i] = float32ToBFloat16Bits(kIn[i])
- }
- case tensor.BFloat16:
- copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return nil, 0, fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
- }
- default:
- return nil, 0, fmt.Errorf("unsupported CPU K dst dtype: %v", kDst.DType())
- }
- switch vDst.DType() {
- case tensor.Float16:
- vOut := vDst.DataUint16()[dstStart : dstStart+length]
- switch vSrc.DType() {
- case tensor.Float32:
- vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- vOut[i] = float32ToFloat16Bits(vIn[i])
- }
- case tensor.Float16:
- copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return nil, 0, fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
- }
- case tensor.BFloat16:
- vOut := vDst.DataUint16()[dstStart : dstStart+length]
- switch vSrc.DType() {
- case tensor.Float32:
- vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- vOut[i] = float32ToBFloat16Bits(vIn[i])
- }
- case tensor.BFloat16:
- copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return nil, 0, fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
- }
- default:
- return nil, 0, fmt.Errorf("unsupported CPU V dst dtype: %v", vDst.DType())
- }
- // Populate packed CPU layout when available and source is float32.
- if block.pk != nil && block.pv != nil && kF32 != nil && vF32 != nil {
- srcStartTok := written * kvDim
- srcEndTok := srcStartTok + length
- writePackedBlockU32(block.pk, block.pv, c.cfg.NumKVHeads, c.cfg.HeadDim, c.pool.cfg.BlockSize, blockOffset, writeCount, kF32[srcStartTok:srcEndTok], vF32[srcStartTok:srcEndTok])
- }
- written += writeCount
- }
- if endPos > c.numWritten {
- c.numWritten = endPos
- }
- return c.viewsLockedAt(layer, endPos), startPos, nil
- }
- // Views returns the live KV block views for a layer.
- func (c *PagedKVCache) Views(layer int) []View {
- c.mu.RLock()
- defer c.mu.RUnlock()
- return c.viewsLockedAt(layer, c.numComputed)
- }
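- // ViewsPacked returns head-major packed CPU views for the written portion of a
- // layer. It returns nil when the layer is not on CPU or no packed buffers exist.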
- func (c *PagedKVCache) ViewsPacked(layer int) []PackedView {
- c.mu.RLock()
- defer c.mu.RUnlock()
- if c.LayerDevice(layer).Type != tensor.CPU {
- return nil
- }
- computed := c.numWritten
- if computed <= 0 {
- return nil
- }
- if layer < 0 || layer >= len(c.blockIDs) {
- return nil
- }
- numBlocks := (computed + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
- views := make([]PackedView, 0, numBlocks)
- for i := 0; i < numBlocks && i < len(c.blockIDs[layer]); i++ {
- block := c.pool.GetBlock(layer, c.blockIDs[layer][i])
- if block == nil || block.pk == nil || block.pv == nil {
- continue
- }
- start := i * c.pool.cfg.BlockSize
- length := c.pool.cfg.BlockSize
- if start+length > computed {
- length = computed - start
- }
- if length <= 0 {
- continue
- }
- views = append(views, PackedView{
- K: block.pk,
- V: block.pv,
- Start: start,
- Length: length,
- BlockSize: c.pool.cfg.BlockSize,
- HeadDim: c.cfg.HeadDim,
- NumKVHeads: c.cfg.NumKVHeads,
- })
- }
- return views
- }
- // viewsLocked returns live views for the currently committed KV length.
- // Kept for internal call sites.
- func (c *PagedKVCache) viewsLocked(layer int) []View {
- return c.viewsLockedAt(layer, c.numComputed)
- }
- func (c *PagedKVCache) viewsLockedAt(layer int, computed int) []View {
- if layer < 0 || layer >= len(c.blockIDs) {
- return nil
- }
- if computed <= 0 {
- return nil
- }
- numBlocks := (computed + c.pool.cfg.BlockSize - 1) / c.pool.cfg.BlockSize
- views := make([]View, 0, numBlocks)
- layerDev := c.pool.LayerDevice(layer).Normalize()
- for i := 0; i < numBlocks && i < len(c.blockIDs[layer]); i++ {
- block := c.pool.GetBlock(layer, c.blockIDs[layer][i])
- if block == nil {
- continue
- }
- start := i * c.pool.cfg.BlockSize
- length := c.pool.cfg.BlockSize
- if start+length > computed {
- length = computed - start
- }
- if length <= 0 {
- continue
- }
- views = append(views, View{
- K: block.K,
- V: block.V,
- Start: start,
- Length: length,
- Device: layerDev.Type,
- GPU: layerDev.GPU,
- })
- }
- return views
- }
- // AppendKV writes K/V values for new tokens into the cache.
- // This is the main write path for prefill and decode.
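- // Unlike Append, AppendKV writes at an explicit startPos, does not grow the block
- // allocation, and does not advance the computed/written counters; callers pair it
- // with AllocateForTokens and Commit.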
- func (c *PagedKVCache) AppendKV(layer int, k, v tensor.Tensor, startPos int) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- if layer < 0 || layer >= len(c.blockIDs) {
- return fmt.Errorf("invalid layer %d", layer)
- }
- layerPlacement := c.LayerDevice(layer).Normalize()
- newTokens := k.Shape()[0]
- kvDim := k.Shape()[1]
- if c.cfg.MaxSeqLen > 0 && startPos+newTokens > c.cfg.MaxSeqLen {
- return fmt.Errorf("KV cache overflow: need %d tokens (start %d + new %d) > MaxSeqLen %d", startPos+newTokens, startPos, newTokens, c.cfg.MaxSeqLen)
- }
- if layerPlacement.Type == tensor.CUDA {
- kSrc, kIsCUDA := k.(*cuda.Tensor)
- vSrc, vIsCUDA := v.(*cuda.Tensor)
- if !kIsCUDA || !vIsCUDA {
- return fmt.Errorf("PagedKVCache layer %d requires CUDA tensors, got %T/%T", layer, k, v)
- }
- written := 0
- for written < newTokens {
- globalPos := startPos + written
- blockIdx := globalPos / c.pool.cfg.BlockSize
- blockOffset := globalPos % c.pool.cfg.BlockSize
- if blockIdx >= len(c.blockIDs[layer]) {
- return fmt.Errorf("block index %d out of range", blockIdx)
- }
- block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
- if block == nil {
- return fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
- }
- // How many tokens can we write to this block?
- writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
- // Copy K/V to block
- srcStart := written * kvDim
- dstStart := blockOffset * kvDim
- length := writeCount * kvDim
- kDst, ok := block.K.(*cuda.Tensor)
- if !ok {
- return fmt.Errorf("block K is not CUDA tensor")
- }
- vDst, ok := block.V.(*cuda.Tensor)
- if !ok {
- return fmt.Errorf("block V is not CUDA tensor")
- }
- if kDst.DType() == tensor.Float16 && kSrc.DType() == tensor.Float32 {
- srcPtr := unsafe.Pointer(uintptr(kSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
- dstPtr := unsafe.Pointer(uintptr(kDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
- if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, kDst.GPU()); err != nil {
- return fmt.Errorf("cast K f32->f16 failed: %w", err)
- }
- } else {
- if err := kDst.CopyPartialFromDevice(dstStart, kSrc, srcStart, length); err != nil {
- return fmt.Errorf("copy K failed: %w", err)
- }
- }
- if vDst.DType() == tensor.Float16 && vSrc.DType() == tensor.Float32 {
- srcPtr := unsafe.Pointer(uintptr(vSrc.Data().(unsafe.Pointer)) + uintptr(srcStart*4))
- dstPtr := unsafe.Pointer(uintptr(vDst.Data().(unsafe.Pointer)) + uintptr(dstStart*2))
- if err := cuda.CastF32ToF16(srcPtr, dstPtr, length, vDst.GPU()); err != nil {
- return fmt.Errorf("cast V f32->f16 failed: %w", err)
- }
- } else {
- if err := vDst.CopyPartialFromDevice(dstStart, vSrc, srcStart, length); err != nil {
- return fmt.Errorf("copy V failed: %w", err)
- }
- }
- written += writeCount
- }
- return nil
- }
- kSrc, kIsCPU := k.(*cpu.Tensor)
- vSrc, vIsCPU := v.(*cpu.Tensor)
- if !kIsCPU || !vIsCPU {
- return fmt.Errorf("PagedKVCache layer %d requires CPU tensors, got %T/%T", layer, k, v)
- }
- written := 0
- for written < newTokens {
- globalPos := startPos + written
- blockIdx := globalPos / c.pool.cfg.BlockSize
- blockOffset := globalPos % c.pool.cfg.BlockSize
- if blockIdx >= len(c.blockIDs[layer]) {
- return fmt.Errorf("block index %d out of range", blockIdx)
- }
- block := c.pool.GetBlock(layer, c.blockIDs[layer][blockIdx])
- if block == nil {
- return fmt.Errorf("block %d not found", c.blockIDs[layer][blockIdx])
- }
- // How many tokens can we write to this block?
- writeCount := min(newTokens-written, c.pool.cfg.BlockSize-blockOffset)
- // Copy K/V to block
- srcStart := written * kvDim
- dstStart := blockOffset * kvDim
- length := writeCount * kvDim
- kDst, ok := block.K.(*cpu.Tensor)
- if !ok {
- return fmt.Errorf("block K is not CPU tensor")
- }
- vDst, ok := block.V.(*cpu.Tensor)
- if !ok {
- return fmt.Errorf("block V is not CPU tensor")
- }
- switch kDst.DType() {
- case tensor.Float16:
- kOut := kDst.DataUint16()[dstStart : dstStart+length]
- switch kSrc.DType() {
- case tensor.Float32:
- kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- kOut[i] = float32ToFloat16Bits(kIn[i])
- }
- case tensor.Float16:
- copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
- }
- case tensor.BFloat16:
- kOut := kDst.DataUint16()[dstStart : dstStart+length]
- switch kSrc.DType() {
- case tensor.Float32:
- kIn := kSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- kOut[i] = float32ToBFloat16Bits(kIn[i])
- }
- case tensor.BFloat16:
- copy(kOut, kSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return fmt.Errorf("unsupported CPU K src dtype: %v", kSrc.DType())
- }
- default:
- return fmt.Errorf("unsupported CPU K dst dtype: %v", kDst.DType())
- }
- switch vDst.DType() {
- case tensor.Float16:
- vOut := vDst.DataUint16()[dstStart : dstStart+length]
- switch vSrc.DType() {
- case tensor.Float32:
- vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- vOut[i] = float32ToFloat16Bits(vIn[i])
- }
- case tensor.Float16:
- copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
- }
- case tensor.BFloat16:
- vOut := vDst.DataUint16()[dstStart : dstStart+length]
- switch vSrc.DType() {
- case tensor.Float32:
- vIn := vSrc.DataFloat32()[srcStart : srcStart+length]
- for i := 0; i < length; i++ {
- vOut[i] = float32ToBFloat16Bits(vIn[i])
- }
- case tensor.BFloat16:
- copy(vOut, vSrc.DataUint16()[srcStart:srcStart+length])
- default:
- return fmt.Errorf("unsupported CPU V src dtype: %v", vSrc.DType())
- }
- default:
- return fmt.Errorf("unsupported CPU V dst dtype: %v", vDst.DType())
- }
- written += writeCount
- }
- return nil
- }
- // Commit marks tokens as computed and caches completed blocks.
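- // Blocks that become fully populated by the newly committed tokens are registered
- // in the pool's prefix cache using the hashes computed in AllocateForTokens and
- // AppendToken.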
- func (c *PagedKVCache) Commit(newTokens int) {
- c.mu.Lock()
- defer c.mu.Unlock()
- oldComputed := c.numComputed
- c.numComputed += newTokens
- if c.numWritten < c.numComputed {
- c.numWritten = c.numComputed
- }
- // Cache newly completed blocks (only those we have hashes for).
- oldBlocks := oldComputed / c.pool.cfg.BlockSize
- newBlocks := c.numComputed / c.pool.cfg.BlockSize
- maxHashBlocks := len(c.blockHashes)
- if newBlocks > maxHashBlocks {
- newBlocks = maxHashBlocks
- }
- if newBlocks > oldBlocks {
- for layer := 0; layer < c.cfg.NumLayers; layer++ {
- blockIDs := c.blockIDs[layer][oldBlocks:newBlocks]
- hashes := c.blockHashes[oldBlocks:newBlocks]
- c.pool.CacheBlocks(layer, blockIDs, hashes)
- }
- }
- }
- // SeqLen returns the number of computed tokens.
- func (c *PagedKVCache) SeqLen() int {
- c.mu.RLock()
- defer c.mu.RUnlock()
- return c.numComputed
- }
- // ContiguousKV is part of the KV cache interface but is intentionally unsupported
- // by the paged cache: it always reports false so callers fall back to Views().
- func (c *PagedKVCache) ContiguousKV(layer, kvLen, kvDim int) (tensor.Tensor, tensor.Tensor, bool, error) {
- c.mu.RLock()
- defer c.mu.RUnlock()
- // NOTE: PagedKVCache does not currently provide a contiguous K/V view.
- // Building one by allocating [kvLen, kvDim] tensors per layer causes large
- // transient CUDA allocations and can trigger GPU OOM under load.
- // Callers should fall back to Views()+concatKVOnDevice.
- return nil, nil, false, nil
- }
- // Free releases all blocks back to the pool.
- func (c *PagedKVCache) Free() {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.clearPtrTablesLocked()
- for layer := 0; layer < len(c.blockIDs); layer++ {
- c.pool.FreeBlocks(layer, c.blockIDs[layer])
- }
- c.blockIDs = nil
- c.ownedBlocks = nil
- c.tokenIDs = nil
- c.blockHashes = nil
- }
- // RequestID returns the request ID.
- func (c *PagedKVCache) RequestID() string {
- return c.requestID
- }
- // NumTokens returns total tokens in sequence.
- func (c *PagedKVCache) NumTokens() int {
- c.mu.RLock()
- defer c.mu.RUnlock()
- return len(c.tokenIDs)
- }
- // BlockSize returns the block size.
- func (c *PagedKVCache) BlockSize() int {
- return c.pool.cfg.BlockSize
- }
- // Truncate rewinds the cache to a specific sequence length.
- func (c *PagedKVCache) Truncate(seqLen int) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if seqLen < 0 {
- seqLen = 0
- }
- if seqLen >= c.numComputed {
- return
- }
- c.numComputed = seqLen
- if c.numWritten > c.numComputed {
- c.numWritten = c.numComputed
- }
- }
- // LayerDevice returns the device placement for a layer.
- func (c *PagedKVCache) LayerDevice(layer int) tensor.DevicePlacement {
- if c.pool != nil {
- return c.pool.LayerDevice(layer)
- }
- return tensor.DevicePlacement{Type: c.cfg.Device, GPU: c.cfg.GPU}
- }
- // MaxSeqLen returns the maximum sequence length.
- func (c *PagedKVCache) MaxSeqLen() int {
- return c.cfg.MaxSeqLen
- }
- // IsOnGPU returns true if the cache is on GPU.
- func (c *PagedKVCache) IsOnGPU() bool {
- if c == nil {
- return false
- }
- if c.pool != nil {
- for i := 0; i < c.cfg.NumLayers; i++ {
- if c.pool.LayerDevice(i).Type == tensor.CUDA {
- return true
- }
- }
- return false
- }
- return c.cfg.Device == tensor.CUDA
- }