cache.go

package kimi_linear

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/device"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
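
// KimiCache holds per-layer decode state: a recurrent state and Q/K/V
// short-convolution buffers for the KDA (linear attention) layers, plus an
// append-only full-attention K/V buffer for the MLA layers.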
type KimiCache struct {
	numLayers     int
	kdaNumHeads   int
	kdaHeadDim    int
	kdaConvKernel int
	mlaNumHeads   int
	mlaKHeadDim   int
	mlaVHeadDim   int

	seqLen int

	recurrent []tensor.Tensor
	convQ     []tensor.Tensor
	convK     []tensor.Tensor
	convV     []tensor.Tensor

	fullKBuf []*cpu.Tensor
	fullVBuf []*cpu.Tensor
	fullLen  []int

	committed int
}
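
// NewKimiCache validates the KDA and MLA head/dimension configuration and
// allocates empty per-layer slots; the underlying tensors are created lazily
// on first access.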
func NewKimiCache(numLayers, kdaNumHeads, kdaHeadDim, kdaConvKernel, mlaNumHeads, mlaKHeadDim, mlaVHeadDim int) (*KimiCache, error) {
	if numLayers <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid numLayers %d", numLayers)
	}
	if kdaNumHeads <= 0 || kdaHeadDim <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid kda heads/dim %d/%d", kdaNumHeads, kdaHeadDim)
	}
	if kdaConvKernel <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid kda conv kernel %d", kdaConvKernel)
	}
	if mlaNumHeads <= 0 || mlaKHeadDim <= 0 || mlaVHeadDim <= 0 {
		return nil, fmt.Errorf("kimi_linear: invalid mla heads/dim %d/%d/%d", mlaNumHeads, mlaKHeadDim, mlaVHeadDim)
	}
	return &KimiCache{
		numLayers:     numLayers,
		kdaNumHeads:   kdaNumHeads,
		kdaHeadDim:    kdaHeadDim,
		kdaConvKernel: kdaConvKernel,
		mlaNumHeads:   mlaNumHeads,
		mlaKHeadDim:   mlaKHeadDim,
		mlaVHeadDim:   mlaVHeadDim,
		recurrent:     make([]tensor.Tensor, numLayers),
		convQ:         make([]tensor.Tensor, numLayers),
		convK:         make([]tensor.Tensor, numLayers),
		convV:         make([]tensor.Tensor, numLayers),
		fullKBuf:      make([]*cpu.Tensor, numLayers),
		fullVBuf:      make([]*cpu.Tensor, numLayers),
		fullLen:       make([]int, numLayers),
	}, nil
}
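
// SeqLen reports the number of tokens committed to the cache so far; it is
// safe to call on a nil cache.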
func (c *KimiCache) SeqLen() int {
	if c == nil {
		return 0
	}
	return c.seqLen
}
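
// Commit records that newTokens tokens have been appended, advancing both the
// committed count and the cached sequence length.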
func (c *KimiCache) Commit(newTokens int) {
	if c == nil {
		return
	}
	c.committed += newTokens
	c.seqLen += newTokens
}
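
// RecurrentState returns the KDA recurrent state for the given layer, creating
// a zero state of shape {kdaNumHeads, kdaHeadDim, kdaHeadDim} on first access
// and migrating it to the requested placement when possible.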
func (c *KimiCache) RecurrentState(layer int, placement tensor.DevicePlacement) (tensor.Tensor, error) {
	if layer < 0 || layer >= c.numLayers {
		return nil, fmt.Errorf("kimi_linear: recurrent state layer out of range: %d", layer)
	}
	placement = placement.Normalize()
	state := c.recurrent[layer]
	if state == nil {
		// Lazily allocate a zero-initialized state, preferring the requested device.
		shape := tensor.Shape{c.kdaNumHeads, c.kdaHeadDim, c.kdaHeadDim}
		if placement.Type == tensor.CUDA && device.CUDAAvailable() {
			if t, err := device.EnsureOn(cpu.NewTensor(shape, nil), placement); err == nil {
				state = t
			} else {
				state = cpu.NewTensor(shape, nil)
			}
		} else {
			state = cpu.NewTensor(shape, nil)
		}
		c.recurrent[layer] = state
		return state, nil
	}
	if twp, ok := state.(tensor.TensorWithPlacement); ok && twp.Placement() == placement {
		return state, nil
	}
	// Migrate state if needed (layer placement is stable, so this should happen at most once).
	moved, err := device.EnsureOn(state, placement)
	if err != nil {
		// Conservative fallback: keep the existing state where it is.
		return state, nil
	}
	c.recurrent[layer] = moved
	return moved, nil
}
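
// ConvStates returns the Q, K, and V short-convolution states for the given
// layer, each of shape {kdaNumHeads*kdaHeadDim, kdaConvKernel-1}, allocating
// them lazily on the CPU and moving them to the requested placement when CUDA
// is available.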
func (c *KimiCache) ConvStates(layer int, placement tensor.DevicePlacement) (tensor.Tensor, tensor.Tensor, tensor.Tensor, error) {
	if layer < 0 || layer >= c.numLayers {
		return nil, nil, nil, fmt.Errorf("kimi_linear: conv state layer out of range: %d", layer)
	}
	placement = placement.Normalize()
	convLen := c.kdaConvKernel - 1
	if convLen < 0 {
		convLen = 0
	}
	projSize := c.kdaNumHeads * c.kdaHeadDim
	shape := tensor.Shape{projSize, convLen}
	q := c.convQ[layer]
	k := c.convK[layer]
	v := c.convV[layer]
	if q == nil {
		q = cpu.NewTensor(shape, nil)
		c.convQ[layer] = q
	}
	if k == nil {
		k = cpu.NewTensor(shape, nil)
		c.convK[layer] = k
	}
	if v == nil {
		v = cpu.NewTensor(shape, nil)
		c.convV[layer] = v
	}
	if placement.Type == tensor.CUDA && device.CUDAAvailable() {
		// Move each state to the requested placement unless it is already there.
		// On migration failure, keep the existing tensor.
		ensure := func(t tensor.Tensor) tensor.Tensor {
			if twp, ok := t.(tensor.TensorWithPlacement); ok && twp.Placement() == placement {
				return t
			}
			if moved, err := device.EnsureOn(t, placement); err == nil {
				return moved
			}
			return t
		}
		q = ensure(q)
		c.convQ[layer] = q
		k = ensure(k)
		c.convK[layer] = k
		v = ensure(v)
		c.convV[layer] = v
	}
	return q, k, v, nil
}
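
// AppendFull appends the given rows of full-attention K/V to the layer's
// cache, growing the backing buffers as needed, and returns the position at
// which the new tokens start.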
func (c *KimiCache) AppendFull(layer int, k, v *cpu.Tensor) (int, error) {
	if layer < 0 || layer >= c.numLayers {
		return 0, fmt.Errorf("kimi_linear: full cache layer out of range: %d", layer)
	}
	if k == nil || v == nil {
		return 0, fmt.Errorf("kimi_linear: nil k/v")
	}
	newTokens := k.Shape()[0]
	if newTokens != v.Shape()[0] {
		return 0, fmt.Errorf("kimi_linear: k/v token mismatch")
	}
	if k.Shape().NumElements() != newTokens*c.mlaNumHeads*c.mlaKHeadDim {
		return 0, fmt.Errorf("kimi_linear: unexpected K shape %v", k.Shape())
	}
	if v.Shape().NumElements() != newTokens*c.mlaNumHeads*c.mlaVHeadDim {
		return 0, fmt.Errorf("kimi_linear: unexpected V shape %v", v.Shape())
	}
	oldTokens := c.fullLen[layer]
	startPos := oldTokens
	kDim := c.mlaNumHeads * c.mlaKHeadDim
	vDim := c.mlaNumHeads * c.mlaVHeadDim
	required := oldTokens + newTokens
	// Grow buffers with an exponential strategy to avoid O(n^2) reallocations.
	kBuf := c.fullKBuf[layer]
	vBuf := c.fullVBuf[layer]
	if kBuf == nil || vBuf == nil {
		capacity := 1
		for capacity < required {
			capacity <<= 1
		}
		if capacity < 64 {
			capacity = 64
		}
		kBuf = cpu.NewTensor(tensor.Shape{capacity, kDim}, nil)
		vBuf = cpu.NewTensor(tensor.Shape{capacity, vDim}, nil)
		c.fullKBuf[layer] = kBuf
		c.fullVBuf[layer] = vBuf
	} else {
		curCap := kBuf.Shape()[0]
		if curCap != vBuf.Shape()[0] {
			return 0, fmt.Errorf("kimi_linear: full cache capacity mismatch")
		}
		if kBuf.Shape()[1] != kDim || vBuf.Shape()[1] != vDim {
			return 0, fmt.Errorf("kimi_linear: full cache dim mismatch")
		}
		if required > curCap {
			newCap := curCap
			if newCap <= 0 {
				newCap = 64
			}
			for newCap < required {
				newCap <<= 1
			}
			newK := cpu.NewTensor(tensor.Shape{newCap, kDim}, nil)
			newV := cpu.NewTensor(tensor.Shape{newCap, vDim}, nil)
			copy(newK.DataFloat32(), kBuf.DataFloat32()[:oldTokens*kDim])
			copy(newV.DataFloat32(), vBuf.DataFloat32()[:oldTokens*vDim])
			kBuf = newK
			vBuf = newV
			c.fullKBuf[layer] = kBuf
			c.fullVBuf[layer] = vBuf
		}
	}
	copy(kBuf.DataFloat32()[oldTokens*kDim:required*kDim], k.DataFloat32())
	copy(vBuf.DataFloat32()[oldTokens*vDim:required*vDim], v.DataFloat32())
	c.fullLen[layer] = required
	return startPos, nil
}
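
// FullKV returns views over the populated prefix of the layer's full-attention
// K/V buffers together with the cached length; ok is false when the layer has
// no cached tokens.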
func (c *KimiCache) FullKV(layer int) (*cpu.Tensor, *cpu.Tensor, int, bool) {
	if layer < 0 || layer >= c.numLayers {
		return nil, nil, 0, false
	}
	kBuf := c.fullKBuf[layer]
	vBuf := c.fullVBuf[layer]
	kvLen := c.fullLen[layer]
	if kBuf == nil || vBuf == nil || kvLen <= 0 {
		return nil, nil, 0, false
	}
	kDim := c.mlaNumHeads * c.mlaKHeadDim
	vDim := c.mlaNumHeads * c.mlaVHeadDim
	kView := cpu.NewTensor(tensor.Shape{kvLen, kDim}, kBuf.DataFloat32()[:kvLen*kDim])
	vView := cpu.NewTensor(tensor.Shape{kvLen, vDim}, vBuf.DataFloat32()[:kvLen*vDim])
	return kView, vView, kvLen, true
}
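
// AsKimiCache reports whether kvCache is a *KimiCache and returns it if so.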
func AsKimiCache(kvCache model.KVCache) (*KimiCache, bool) {
	c, ok := kvCache.(*KimiCache)
	return c, ok
}
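
// exampleDecodeStep is a minimal usage sketch for the full-attention side of
// the cache: one decode step on layer 0. The layer count, head counts, and
// dimensions below are arbitrary illustrative choices, and the function itself
// is only an example; it is not referenced elsewhere.
func exampleDecodeStep() error {
	cache, err := NewKimiCache(1, 4, 64, 4, 8, 128, 128)
	if err != nil {
		return err
	}
	// One new token of MLA K/V, laid out as {tokens, numHeads*headDim}.
	k := cpu.NewTensor(tensor.Shape{1, 8 * 128}, nil)
	v := cpu.NewTensor(tensor.Shape{1, 8 * 128}, nil)
	startPos, err := cache.AppendFull(0, k, v)
	if err != nil {
		return err
	}
	cache.Commit(1)
	if _, _, kvLen, ok := cache.FullKV(0); !ok || kvLen != startPos+1 {
		return fmt.Errorf("unexpected cache state")
	}
	return nil
}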