package kimi_linear

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/device"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
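
// KimiCache holds per-layer decode state for a Kimi-linear style model:
// a recurrent state and q/k/v short-convolution states for the KDA (linear
// attention) layers, and an append-only full K/V cache for the MLA (full
// attention) layers.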
type KimiCache struct {
	numLayers int

	// KDA (linear attention) geometry.
	kdaNumHeads   int
	kdaHeadDim    int
	kdaConvKernel int

	// MLA (full attention) geometry.
	mlaNumHeads int
	mlaKHeadDim int
	mlaVHeadDim int

	// seqLen is the total number of tokens committed so far.
	seqLen int

	// Per-layer KDA state: recurrent state and q/k/v short-convolution windows.
	recurrent []tensor.Tensor
	convQ     []tensor.Tensor
	convK     []tensor.Tensor
	convV     []tensor.Tensor

	// Per-layer MLA full K/V buffers (grown exponentially) and the number of
	// valid token rows in each.
	fullKBuf []*cpu.Tensor
	fullVBuf []*cpu.Tensor
	fullLen  []int

	committed int
}
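
// NewKimiCache allocates an empty cache for numLayers layers after validating
// the KDA and MLA head/dimension configuration. Per-layer states are created
// lazily on first access.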
func NewKimiCache(numLayers int, kdaNumHeads int, kdaHeadDim int, kdaConvKernel int, mlaNumHeads int, mlaKHeadDim int, mlaVHeadDim int) (*KimiCache, error) {
	if numLayers <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid numLayers %d", numLayers)
	}
	if kdaNumHeads <= 0 || kdaHeadDim <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid kda heads/dim %d/%d", kdaNumHeads, kdaHeadDim)
	}
	if kdaConvKernel <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid kda conv kernel %d", kdaConvKernel)
	}
	if mlaNumHeads <= 0 || mlaKHeadDim <= 0 || mlaVHeadDim <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid mla heads/dim %d/%d/%d", mlaNumHeads, mlaKHeadDim, mlaVHeadDim)
	}
	return &KimiCache{
		numLayers:     numLayers,
		kdaNumHeads:   kdaNumHeads,
		kdaHeadDim:    kdaHeadDim,
		kdaConvKernel: kdaConvKernel,
		mlaNumHeads:   mlaNumHeads,
		mlaKHeadDim:   mlaKHeadDim,
		mlaVHeadDim:   mlaVHeadDim,
		recurrent:     make([]tensor.Tensor, numLayers),
		convQ:         make([]tensor.Tensor, numLayers),
		convK:         make([]tensor.Tensor, numLayers),
		convV:         make([]tensor.Tensor, numLayers),
		fullKBuf:      make([]*cpu.Tensor, numLayers),
		fullVBuf:      make([]*cpu.Tensor, numLayers),
		fullLen:       make([]int, numLayers),
	}, nil
}
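
// SeqLen returns the number of tokens committed to the cache so far.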
func (c *KimiCache) SeqLen() int {
	if c == nil {
		return 0
	}
	return c.seqLen
}
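
// Commit advances the committed token count and the cached sequence length by
// newTokens.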
func (c *KimiCache) Commit(newTokens int) {
	if c == nil {
		return
	}
	c.committed += newTokens
	c.seqLen += newTokens
}
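
// RecurrentState returns the KDA recurrent state for the given layer, shaped
// [kdaNumHeads, kdaHeadDim, kdaHeadDim]. It is allocated lazily and moved to
// the requested placement when possible; if a move fails, the existing state
// is returned unchanged.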
func (c *KimiCache) RecurrentState(layer int, placement tensor.DevicePlacement) (tensor.Tensor, error) {
	if layer < 0 || layer >= c.numLayers {
		return nil, fmt.Errorf("kimi_linear: recurrent state layer out of range: %d", layer)
	}
	placement = placement.Normalize()
	state := c.recurrent[layer]
	if state == nil {
		shape := tensor.Shape{c.kdaNumHeads, c.kdaHeadDim, c.kdaHeadDim}
		if placement.Type == tensor.CUDA && device.CUDAAvailable() {
			t, err := device.EnsureOn(cpu.NewTensor(shape, nil), placement)
			if err == nil {
				state = t
			} else {
				state = cpu.NewTensor(shape, nil)
			}
		} else {
			state = cpu.NewTensor(shape, nil)
		}
		c.recurrent[layer] = state
		return state, nil
	}
	if twp, ok := state.(tensor.TensorWithPlacement); ok {
		if twp.Placement() == placement {
			return state, nil
		}
	}
	// Migrate state if needed (layer placement is stable, so this should happen at most once).
	moved, err := device.EnsureOn(state, placement)
	if err != nil {
		// Conservative fallback: keep CPU state.
		return state, nil
	}
	c.recurrent[layer] = moved
	return moved, nil
}
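
// ConvStates returns the q/k/v short-convolution states for the given layer,
// each shaped [kdaNumHeads*kdaHeadDim, kdaConvKernel-1]. States are allocated
// lazily on CPU and moved to the requested placement when CUDA is available.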
func (c *KimiCache) ConvStates(layer int, placement tensor.DevicePlacement) (tensor.Tensor, tensor.Tensor, tensor.Tensor, error) {
	if layer < 0 || layer >= c.numLayers {
		return nil, nil, nil, fmt.Errorf("kimi_linear: conv state layer out of range: %d", layer)
	}
	placement = placement.Normalize()
	convLen := c.kdaConvKernel - 1
	projSize := c.kdaNumHeads * c.kdaHeadDim
	shape := tensor.Shape{projSize, convLen}
	if convLen <= 0 {
		shape = tensor.Shape{projSize, 0}
	}
	q := c.convQ[layer]
	k := c.convK[layer]
	v := c.convV[layer]
	if q == nil {
		q = cpu.NewTensor(shape, nil)
		c.convQ[layer] = q
	}
	if k == nil {
		k = cpu.NewTensor(shape, nil)
		c.convK[layer] = k
	}
	if v == nil {
		v = cpu.NewTensor(shape, nil)
		c.convV[layer] = v
	}
	if placement.Type == tensor.CUDA && device.CUDAAvailable() {
		// Move each conv state onto the requested device; keep the existing
		// tensor if the move fails.
		ensure := func(t tensor.Tensor) tensor.Tensor {
			if twp, ok := t.(tensor.TensorWithPlacement); ok && twp.Placement() == placement {
				return t
			}
			if moved, err := device.EnsureOn(t, placement); err == nil {
				return moved
			}
			return t
		}
		q = ensure(q)
		k = ensure(k)
		v = ensure(v)
		c.convQ[layer] = q
		c.convK[layer] = k
		c.convV[layer] = v
	}
	return q, k, v, nil
}
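
// AppendFull appends new MLA K/V rows for the given layer. The first dimension
// of k and v is the token count, and their total element counts must match the
// MLA head geometry. Backing buffers grow by doubling; the return value is the
// starting token position of the appended rows.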
func (c *KimiCache) AppendFull(layer int, k, v *cpu.Tensor) (int, error) {
	if layer < 0 || layer >= c.numLayers {
		return 0, fmt.Errorf("kimi_linear: full cache layer out of range: %d", layer)
	}
	if k == nil || v == nil {
		return 0, fmt.Errorf("kimi_linear: nil k/v")
	}
	newTokens := k.Shape()[0]
	if newTokens != v.Shape()[0] {
		return 0, fmt.Errorf("kimi_linear: k/v token mismatch")
	}
	if k.Shape().NumElements() != newTokens*c.mlaNumHeads*c.mlaKHeadDim {
		return 0, fmt.Errorf("kimi_linear: unexpected K shape %v", k.Shape())
	}
	if v.Shape().NumElements() != newTokens*c.mlaNumHeads*c.mlaVHeadDim {
		return 0, fmt.Errorf("kimi_linear: unexpected V shape %v", v.Shape())
	}
	oldTokens := c.fullLen[layer]
	startPos := oldTokens
	kDim := c.mlaNumHeads * c.mlaKHeadDim
	vDim := c.mlaNumHeads * c.mlaVHeadDim
	required := oldTokens + newTokens
	// Grow buffers with exponential strategy to avoid O(n^2) reallocations.
	kBuf := c.fullKBuf[layer]
	vBuf := c.fullVBuf[layer]
	if kBuf == nil || vBuf == nil {
		capacity := 1
		for capacity < required {
			capacity <<= 1
		}
		if capacity < 64 {
			capacity = 64
		}
		kBuf = cpu.NewTensor(tensor.Shape{capacity, kDim}, nil)
		vBuf = cpu.NewTensor(tensor.Shape{capacity, vDim}, nil)
		c.fullKBuf[layer] = kBuf
		c.fullVBuf[layer] = vBuf
	} else {
		curCap := kBuf.Shape()[0]
		if curCap != vBuf.Shape()[0] {
			return 0, fmt.Errorf("kimi_linear: full cache capacity mismatch")
		}
		if kBuf.Shape()[1] != kDim || vBuf.Shape()[1] != vDim {
			return 0, fmt.Errorf("kimi_linear: full cache dim mismatch")
		}
		if required > curCap {
			newCap := curCap
			if newCap <= 0 {
				newCap = 64
			}
			for newCap < required {
				newCap <<= 1
			}
			newK := cpu.NewTensor(tensor.Shape{newCap, kDim}, nil)
			newV := cpu.NewTensor(tensor.Shape{newCap, vDim}, nil)
			copy(newK.DataFloat32(), kBuf.DataFloat32()[:oldTokens*kDim])
			copy(newV.DataFloat32(), vBuf.DataFloat32()[:oldTokens*vDim])
			kBuf = newK
			vBuf = newV
			c.fullKBuf[layer] = kBuf
			c.fullVBuf[layer] = vBuf
		}
	}
	copy(kBuf.DataFloat32()[oldTokens*kDim:required*kDim], k.DataFloat32())
	copy(vBuf.DataFloat32()[oldTokens*vDim:required*vDim], v.DataFloat32())
	c.fullLen[layer] = required
	return startPos, nil
}
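
// FullKV returns the layer's cached MLA K and V as [kvLen, dim] tensors built
// over the first kvLen rows of the backing buffers, together with kvLen. It
// reports false if the layer has no cached tokens.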
func (c *KimiCache) FullKV(layer int) (*cpu.Tensor, *cpu.Tensor, int, bool) {
	if layer < 0 || layer >= c.numLayers {
		return nil, nil, 0, false
	}
	kBuf := c.fullKBuf[layer]
	vBuf := c.fullVBuf[layer]
	kvLen := c.fullLen[layer]
	if kBuf == nil || vBuf == nil || kvLen <= 0 {
		return nil, nil, 0, false
	}
	kDim := c.mlaNumHeads * c.mlaKHeadDim
	vDim := c.mlaNumHeads * c.mlaVHeadDim
	kView := cpu.NewTensor(tensor.Shape{kvLen, kDim}, kBuf.DataFloat32()[:kvLen*kDim])
	vView := cpu.NewTensor(tensor.Shape{kvLen, vDim}, vBuf.DataFloat32()[:kvLen*vDim])
	return kView, vView, kvLen, true
}
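
// AsKimiCache reports whether kvCache is a *KimiCache, returning it if so.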
func AsKimiCache(kvCache model.KVCache) (*KimiCache, bool) {
	c, ok := kvCache.(*KimiCache)
	return c, ok
}