package kimi_linear

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/device"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)

// KimiCache holds per-layer decoding state for a Kimi Linear style model:
// a recurrent state plus short-convolution tails for the linear-attention
// (KDA) layers, and an append-only K/V buffer for the full-attention (MLA)
// layers.
type KimiCache struct {
	numLayers     int
	kdaNumHeads   int
	kdaHeadDim    int
	kdaConvKernel int
	mlaNumHeads   int
	mlaKHeadDim   int
	mlaVHeadDim   int

	seqLen int

	// Per-layer KDA state: recurrent [heads, dim, dim] matrices and q/k/v conv tails.
	recurrent []tensor.Tensor
	convQ     []tensor.Tensor
	convK     []tensor.Tensor
	convV     []tensor.Tensor

	// Per-layer MLA state: flattened K/V rows plus the number of valid rows.
	fullKBuf []*cpu.Tensor
	fullVBuf []*cpu.Tensor
	fullLen  []int

	committed int // total tokens accepted via Commit
}

// NewKimiCache validates the KDA/MLA head and dimension configuration and
// allocates empty per-layer state slots.
func NewKimiCache(numLayers, kdaNumHeads, kdaHeadDim, kdaConvKernel, mlaNumHeads, mlaKHeadDim, mlaVHeadDim int) (*KimiCache, error) {
	if numLayers <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid numLayers %d", numLayers)
	}
	if kdaNumHeads <= 0 || kdaHeadDim <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid kda heads/dim %d/%d", kdaNumHeads, kdaHeadDim)
	}
	if kdaConvKernel <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid kda conv kernel %d", kdaConvKernel)
	}
	if mlaNumHeads <= 0 || mlaKHeadDim <= 0 || mlaVHeadDim <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid mla heads/dim %d/%d/%d", mlaNumHeads, mlaKHeadDim, mlaVHeadDim)
	}
	return &KimiCache{
		numLayers:     numLayers,
		kdaNumHeads:   kdaNumHeads,
		kdaHeadDim:    kdaHeadDim,
		kdaConvKernel: kdaConvKernel,
		mlaNumHeads:   mlaNumHeads,
		mlaKHeadDim:   mlaKHeadDim,
		mlaVHeadDim:   mlaVHeadDim,
		recurrent:     make([]tensor.Tensor, numLayers),
		convQ:         make([]tensor.Tensor, numLayers),
		convK:         make([]tensor.Tensor, numLayers),
		convV:         make([]tensor.Tensor, numLayers),
		fullKBuf:      make([]*cpu.Tensor, numLayers),
		fullVBuf:      make([]*cpu.Tensor, numLayers),
		fullLen:       make([]int, numLayers),
	}, nil
}

// SeqLen returns the number of committed tokens.
func (c *KimiCache) SeqLen() int {
	if c == nil {
		return 0
	}
	return c.seqLen
}

// Commit records that newTokens tokens have been accepted, advancing the
// committed sequence length.
func (c *KimiCache) Commit(newTokens int) {
	if c == nil {
		return
	}
	c.committed += newTokens
	c.seqLen += newTokens
}

// RecurrentState returns the lazily allocated [numHeads, headDim, headDim]
// recurrent state for layer, placed on the requested device when possible.
func (c *KimiCache) RecurrentState(layer int, placement tensor.DevicePlacement) (tensor.Tensor, error) {
	if layer < 0 || layer >= c.numLayers {
		return nil, fmt.Errorf("kimi_linear: recurrent state layer out of range: %d", layer)
	}
	placement = placement.Normalize()
	state := c.recurrent[layer]
	if state == nil {
		shape := tensor.Shape{c.kdaNumHeads, c.kdaHeadDim, c.kdaHeadDim}
		if placement.Type == tensor.CUDA && device.CUDAAvailable() {
			t, err := device.EnsureOn(cpu.NewTensor(shape, nil), placement)
			if err == nil {
				state = t
			} else {
				state = cpu.NewTensor(shape, nil)
			}
		} else {
			state = cpu.NewTensor(shape, nil)
		}
		c.recurrent[layer] = state
		return state, nil
	}
	if twp, ok := state.(tensor.TensorWithPlacement); ok {
		if twp.Placement() == placement {
			return state, nil
		}
	}
	// Migrate state if needed (layer placement is stable, so this should happen at most once).
	moved, err := device.EnsureOn(state, placement)
	if err != nil {
		// Conservative fallback: keep the CPU-resident state.
		return state, nil
	}
	c.recurrent[layer] = moved
	return moved, nil
}

// ConvStates returns the per-layer q/k/v short-convolution tail buffers of
// shape [numHeads*headDim, convKernel-1], lazily allocated on CPU and migrated
// to the requested placement when CUDA is available.
func (c *KimiCache) ConvStates(layer int, placement tensor.DevicePlacement) (tensor.Tensor, tensor.Tensor, tensor.Tensor, error) {
	if layer < 0 || layer >= c.numLayers {
		return nil, nil, nil, fmt.Errorf("kimi_linear: conv state layer out of range: %d", layer)
	}
	placement = placement.Normalize()
	convLen := c.kdaConvKernel - 1
	projSize := c.kdaNumHeads * c.kdaHeadDim
	shape := tensor.Shape{projSize, convLen}
	if convLen <= 0 {
		shape = tensor.Shape{projSize, 0}
	}
	q := c.convQ[layer]
	k := c.convK[layer]
	v := c.convV[layer]
	if q == nil {
		q = cpu.NewTensor(shape, nil)
		c.convQ[layer] = q
	}
	if k == nil {
		k = cpu.NewTensor(shape, nil)
		c.convK[layer] = k
	}
	if v == nil {
		v = cpu.NewTensor(shape, nil)
		c.convV[layer] = v
	}
	if placement.Type == tensor.CUDA && device.CUDAAvailable() {
		// Move each conv state to the requested placement unless it is already
		// there. Migration failures are non-fatal; the CPU copy keeps being used.
		migrate := func(t tensor.Tensor, slot []tensor.Tensor) tensor.Tensor {
			if twp, ok := t.(tensor.TensorWithPlacement); ok && twp.Placement() == placement {
				return t
			}
			if moved, err := device.EnsureOn(t, placement); err == nil {
				slot[layer] = moved
				return moved
			}
			return t
		}
		q = migrate(q, c.convQ)
		k = migrate(k, c.convK)
		v = migrate(v, c.convV)
	}
	return q, k, v, nil
}

// AppendFull appends newTokens rows of flattened MLA keys/values (shaped
// [tokens, numHeads*headDim]) to the layer's full-attention cache and returns
// the position of the first appended token.
func (c *KimiCache) AppendFull(layer int, k, v *cpu.Tensor) (int, error) {
	if layer < 0 || layer >= c.numLayers {
		return 0, fmt.Errorf("kimi_linear: full cache layer out of range: %d", layer)
	}
	if k == nil || v == nil {
		return 0, fmt.Errorf("kimi_linear: nil k/v")
	}
	newTokens := k.Shape()[0]
	if newTokens != v.Shape()[0] {
		return 0, fmt.Errorf("kimi_linear: k/v token mismatch")
	}
	if k.Shape().NumElements() != newTokens*c.mlaNumHeads*c.mlaKHeadDim {
		return 0, fmt.Errorf("kimi_linear: unexpected K shape %v", k.Shape())
	}
	if v.Shape().NumElements() != newTokens*c.mlaNumHeads*c.mlaVHeadDim {
		return 0, fmt.Errorf("kimi_linear: unexpected V shape %v", v.Shape())
	}
	oldTokens := c.fullLen[layer]
	startPos := oldTokens
	kDim := c.mlaNumHeads * c.mlaKHeadDim
	vDim := c.mlaNumHeads * c.mlaVHeadDim
	required := oldTokens + newTokens
	// Grow buffers with an exponential strategy to avoid O(n^2) reallocations.
	kBuf := c.fullKBuf[layer]
	vBuf := c.fullVBuf[layer]
	if kBuf == nil || vBuf == nil {
		capacity := 1
		for capacity < required {
			capacity <<= 1
		}
		if capacity < 64 {
			capacity = 64
		}
		kBuf = cpu.NewTensor(tensor.Shape{capacity, kDim}, nil)
		vBuf = cpu.NewTensor(tensor.Shape{capacity, vDim}, nil)
		c.fullKBuf[layer] = kBuf
		c.fullVBuf[layer] = vBuf
	} else {
		curCap := kBuf.Shape()[0]
		if curCap != vBuf.Shape()[0] {
			return 0, fmt.Errorf("kimi_linear: full cache capacity mismatch")
		}
		if kBuf.Shape()[1] != kDim || vBuf.Shape()[1] != vDim {
			return 0, fmt.Errorf("kimi_linear: full cache dim mismatch")
		}
		if required > curCap {
			newCap := curCap
			if newCap <= 0 {
				newCap = 64
			}
			for newCap < required {
				newCap <<= 1
			}
			newK := cpu.NewTensor(tensor.Shape{newCap, kDim}, nil)
			newV := cpu.NewTensor(tensor.Shape{newCap, vDim}, nil)
			copy(newK.DataFloat32(), kBuf.DataFloat32()[:oldTokens*kDim])
			copy(newV.DataFloat32(), vBuf.DataFloat32()[:oldTokens*vDim])
			kBuf = newK
			vBuf = newV
			c.fullKBuf[layer] = kBuf
			c.fullVBuf[layer] = vBuf
		}
	}
	copy(kBuf.DataFloat32()[oldTokens*kDim:required*kDim], k.DataFloat32())
	copy(vBuf.DataFloat32()[oldTokens*vDim:required*vDim], v.DataFloat32())
	c.fullLen[layer] = required
	return startPos, nil
}

// FullKV returns contiguous [kvLen, dim] views over the layer's cached keys
// and values, along with kvLen. It reports false if the layer has no cached
// tokens yet.
func (c *KimiCache) FullKV(layer int) (*cpu.Tensor, *cpu.Tensor, int, bool) {
	if layer < 0 || layer >= c.numLayers {
		return nil, nil, 0, false
	}
	kBuf := c.fullKBuf[layer]
	vBuf := c.fullVBuf[layer]
	kvLen := c.fullLen[layer]
	if kBuf == nil || vBuf == nil || kvLen <= 0 {
		return nil, nil, 0, false
	}
	kDim := c.mlaNumHeads * c.mlaKHeadDim
	vDim := c.mlaNumHeads * c.mlaVHeadDim
	kView := cpu.NewTensor(tensor.Shape{kvLen, kDim}, kBuf.DataFloat32()[:kvLen*kDim])
	vView := cpu.NewTensor(tensor.Shape{kvLen, vDim}, vBuf.DataFloat32()[:kvLen*vDim])
	return kView, vView, kvLen, true
}

// AsKimiCache reports whether kvCache is a *KimiCache and, if so, returns it.
func AsKimiCache(kvCache model.KVCache) (*KimiCache, bool) {
	c, ok := kvCache.(*KimiCache)
	return c, ok
}
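
// exampleDecodeStep is an illustrative sketch (not referenced anywhere else in
// this package) of how a caller might drive KimiCache for a single decode step
// on an MLA layer. The layer count and head/dimension values below are
// arbitrary placeholders, not taken from any real Kimi configuration.
func exampleDecodeStep() error {
	const (
		numLayers  = 2
		kdaHeads   = 4
		kdaDim     = 32
		convKernel = 4
		mlaHeads   = 4
		mlaKDim    = 64
		mlaVDim    = 64
	)
	cache, err := NewKimiCache(numLayers, kdaHeads, kdaDim, convKernel, mlaHeads, mlaKDim, mlaVDim)
	if err != nil {
		return err
	}

	// Append one token's worth of flattened MLA keys/values to layer 0.
	// AppendFull reports the position at which the new rows were written.
	k := cpu.NewTensor(tensor.Shape{1, mlaHeads * mlaKDim}, nil)
	v := cpu.NewTensor(tensor.Shape{1, mlaHeads * mlaVDim}, nil)
	startPos, err := cache.AppendFull(0, k, v)
	if err != nil {
		return err
	}
	if startPos != 0 {
		return fmt.Errorf("expected first token at position 0, got %d", startPos)
	}

	// Read the accumulated cache back as contiguous [kvLen, dim] views.
	if _, _, kvLen, ok := cache.FullKV(0); !ok || kvLen != 1 {
		return fmt.Errorf("unexpected full-cache length %d", kvLen)
	}

	// Advance the committed sequence length once the step is accepted.
	cache.Commit(1)
	return nil
}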