forward_device.go

package kimi_linear

import (
	"context"
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/compute"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
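
// Forward runs a single forward pass over the input token IDs: token embedding,
// a stack of hybrid layers (KDA linear attention or full MLA attention, each
// followed by an MoE or dense SwiGLU FFN), a final RMS norm, and the LM head.
// It returns logits of shape [seqLen, vocabSize].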
func (m *Model) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache model.KVCache) (tensor.Tensor, error) {
	cfg := m.config
	seqLen := input.Shape()[0]
	ids := nn.ParseTokenIDs(input)
	posArr := nn.ParsePositions(positions, seqLen)

	dispatcher := compute.DispatcherFromContext(ctx)
	scratchSet := compute.ScratchSetFromContext(ctx)
	baseScratch := compute.ScratchFromContext(ctx)
	cpuMoE := compute.CPUMoEFromContext(ctx)

	// Track GPU allocations for cleanup at end of forward pass
	var gpuAllocations []*compute.Activation
	defer func() {
		for _, act := range gpuAllocations {
			compute.FreeActivation(act)
		}
	}()

	hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
	if err != nil {
		return nil, fmt.Errorf("embedding: %w", err)
	}
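
	// Reuse the caller-provided KimiCache when one was passed in; otherwise
	// allocate a fresh cache sized from the KDA and MLA head configuration.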
	var cache *KimiCache
	if kvCache != nil {
		if c, ok := AsKimiCache(kvCache); ok {
			cache = c
		}
	}
	if cache == nil {
		kdaCfg, _ := parseLinearAttnConfig(cfg)
		cache, err = NewKimiCache(cfg.NumLayers, m.kdaNumHeads, m.kdaHeadDim, kdaCfg.ShortConvKernel, m.mlaNumHeads, m.mlaKHeadDim, m.mlaVHeadDim)
		if err != nil {
			return nil, err
		}
	}

	eps := float32(cfg.RMSNormEps)
	if eps == 0 {
		eps = 1e-5
	}
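
	// Per-layer pass. Each layer gets its own compute context; on CUDA
	// placements the layer reuses (and resets) the scratch space for that GPU.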
	for i, layer := range m.layers {
		compCtx := compute.NewContext(dispatcher, i)
		compCtx.CPUMoE = cpuMoE
		if p := compCtx.Placement(); p.Type == tensor.CUDA {
			var layerScratch *compute.ScratchSpace
			if scratchSet != nil {
				layerScratch = scratchSet.Scratch(p.GPU)
			} else if baseScratch != nil && baseScratch.GPU() == p.GPU {
				layerScratch = baseScratch
			}
			if layerScratch != nil {
				layerScratch.Reset()
			}
			compCtx.Scratch = layerScratch
		}
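
		// allocAct prefers the per-layer GPU scratch space; anything allocated
		// outside scratch on CUDA is recorded in gpuAllocations and freed by the
		// deferred cleanup above.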
		allocAct := func(shape tensor.Shape) (*compute.Activation, error) {
			if compCtx.Scratch != nil && compCtx.Placement().Type == tensor.CUDA {
				if act, err := compCtx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
					return act, nil
				}
			}
			act, err := compute.NewActivation(shape, compCtx.Placement())
			if err == nil && act != nil && compCtx.Placement().Type == tensor.CUDA {
				// Track GPU allocations that are not from scratch
				gpuAllocations = append(gpuAllocations, act)
			}
			return act, err
		}

		// Ensure activations are on the target device for this layer.
		if _, err := hidden.EnsureOn(compCtx.Placement()); err != nil {
			return nil, err
		}

		// Save residual BEFORE layernorm (pre-norm architecture)
		residualAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
		if err != nil {
			return nil, err
		}
		if err := compute.HybridCopy(compCtx, residualAct, hidden); err != nil {
			return nil, err
		}

		if layer.inputLayernorm == nil {
			return nil, fmt.Errorf("layer %d: missing input_layernorm", i)
		}
		if err := compute.HybridRMSNorm(compCtx, hidden, layer.inputLayernorm, eps); err != nil {
			return nil, fmt.Errorf("layer %d: attn norm: %w", i, err)
		}

		attnOutAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
		if err != nil {
			return nil, err
		}
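
		// Attention: KDA layers run the linear-attention path with cached
		// short-conv and recurrent state; full layers run MLA attention over the
		// cached K/V sequence.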
		switch layer.kind {
		case layerKindKDA:
			if layer.kdaQProj == nil || layer.kdaKProj == nil || layer.kdaVProj == nil || layer.kdaOProj == nil {
				return nil, fmt.Errorf("layer %d: missing KDA projections", i)
			}
			if layer.kdaQConv == nil || layer.kdaKConv == nil || layer.kdaVConv == nil {
				return nil, fmt.Errorf("layer %d: missing KDA conv weights", i)
			}
			if layer.kdaFAProj == nil || layer.kdaFBProj == nil || layer.kdaBProj == nil {
				return nil, fmt.Errorf("layer %d: missing KDA gate/beta projections", i)
			}
			if layer.kdaALog == nil || layer.kdaDTBias == nil {
				return nil, fmt.Errorf("layer %d: missing KDA A_log/dt_bias", i)
			}
			if layer.kdaGAProj == nil || layer.kdaGBProj == nil {
				return nil, fmt.Errorf("layer %d: missing KDA o_norm gate projections", i)
			}
			kdaCfg, _ := parseLinearAttnConfig(m.config)
			convQ, convK, convV, err := cache.ConvStates(i, compCtx.Placement())
			if err != nil {
				return nil, err
			}
			stT, err := cache.RecurrentState(i, compCtx.Placement())
			if err != nil {
				return nil, err
			}
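
			// Single fused KDA call: consumes the q/k/v projections and short-conv
			// weights, the cached conv and recurrent states, and the A_log/dt_bias
			// plus gate/beta parameters, then writes the normalized, projected
			// output into attnOutAct.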
			if err := compute.HybridKDA(
				compCtx,
				hidden,
				layer.kdaQProj, layer.kdaKProj, layer.kdaVProj,
				layer.kdaQConv, layer.kdaKConv, layer.kdaVConv,
				layer.kdaFAProj, layer.kdaFBProj, layer.kdaBProj,
				layer.kdaALog,
				layer.kdaDTBias,
				layer.kdaGAProj, layer.kdaGBProj,
				layer.kdaONorm,
				layer.kdaOProj,
				convQ, convK, convV,
				stT,
				seqLen, m.kdaNumHeads, m.kdaHeadDim, kdaCfg.ShortConvKernel,
				eps,
				attnOutAct,
			); err != nil {
				return nil, fmt.Errorf("layer %d: kda: %w", i, err)
			}
		case layerKindFull:
			if layer.mlaQProj == nil || layer.mlaKVAProjWithMQA == nil || layer.mlaKVALayernorm == nil || layer.mlaKVBProj == nil || layer.mlaOProj == nil {
				return nil, fmt.Errorf("layer %d: missing MLA weights", i)
			}
			mlaCfg, _ := parseMLAConfig(m.config)
			qkNope := mlaCfg.QKNopeHeadDim
			qkRope := mlaCfg.QKRopeHeadDim
			vDim := mlaCfg.VHeadDim
			kvARank := mlaCfg.KVLoraRank
			if kvARank <= 0 || qkNope <= 0 || qkRope <= 0 || vDim <= 0 {
				return nil, fmt.Errorf("layer %d: invalid MLA config", i)
			}
			if qkNope+qkRope != m.mlaKHeadDim {
				return nil, fmt.Errorf("layer %d: mla head dim mismatch", i)
			}
			if vDim != m.mlaVHeadDim {
				return nil, fmt.Errorf("layer %d: mla v dim mismatch", i)
			}
			qAct, err := compute.NewActivation(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaKHeadDim}, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlaQProj, qAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla q_proj: %w", i, err)
			}
			qCPU, _ := qAct.AsCPU()
			kvAAct, err := compute.NewActivation(tensor.Shape{seqLen, kvARank + qkRope}, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlaKVAProjWithMQA, kvAAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla kv_a_proj: %w", i, err)
			}
			kvACPU, _ := kvAAct.AsCPU()
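
			// kv_a layout per token: [ kvARank compressed KV latent | qkRope rotary key part ].
			// Extract the latent, RMS-norm it, and expand it through kv_b_proj; the
			// rope slice is shared across all heads below.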
			kvPass := cpu.NewTensor(tensor.Shape{seqLen, kvARank}, nil)
			for t := 0; t < seqLen; t++ {
				copy(kvPass.DataFloat32()[t*kvARank:(t+1)*kvARank], kvACPU.DataFloat32()[t*(kvARank+qkRope):t*(kvARank+qkRope)+kvARank])
			}
			kvPassAct := compute.NewActivationFrom(kvPass)
			if compCtx.IsGPU() {
				if _, err := kvPassAct.EnsureOn(compCtx.Placement()); err != nil {
					return nil, err
				}
			}
			if err := compute.HybridRMSNorm(compCtx, kvPassAct, layer.mlaKVALayernorm, eps); err != nil {
				return nil, fmt.Errorf("layer %d: mla kv_a_layernorm: %w", i, err)
			}
			kvBAct, err := allocAct(tensor.Shape{seqLen, m.mlaNumHeads * (qkNope + vDim)})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, kvPassAct, layer.mlaKVBProj, kvBAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla kv_b_proj: %w", i, err)
			}
			kvBCPU, _ := kvBAct.AsCPU()
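
			// kv_b output is laid out per head as [ qkNope key part | vDim value ].
			// Build full per-head keys by appending the shared rope slice to each
			// head's nope part, and copy values straight through.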
			kStep := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaKHeadDim}, nil)
			vStep := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaVHeadDim}, nil)
			for t := 0; t < seqLen; t++ {
				kRotBase := t*(kvARank+qkRope) + kvARank
				kRot := kvACPU.DataFloat32()[kRotBase : kRotBase+qkRope]
				for h := 0; h < m.mlaNumHeads; h++ {
					srcBase := t*(m.mlaNumHeads*(qkNope+vDim)) + h*(qkNope+vDim)
					kDstBase := t*(m.mlaNumHeads*m.mlaKHeadDim) + h*m.mlaKHeadDim
					vDstBase := t*(m.mlaNumHeads*m.mlaVHeadDim) + h*m.mlaVHeadDim
					copy(kStep.DataFloat32()[kDstBase:kDstBase+qkNope], kvBCPU.DataFloat32()[srcBase:srcBase+qkNope])
					copy(kStep.DataFloat32()[kDstBase+qkNope:kDstBase+qkNope+qkRope], kRot)
					copy(vStep.DataFloat32()[vDstBase:vDstBase+vDim], kvBCPU.DataFloat32()[srcBase+qkNope:srcBase+qkNope+vDim])
				}
			}
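
			// Append this step's K/V to the per-layer cache, run causal attention
			// (CPU kernel) over the full cached sequence starting at the returned
			// cache offset, then project back to the hidden size via o_proj.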
			startPos, err := cache.AppendFull(i, kStep, vStep)
			if err != nil {
				return nil, fmt.Errorf("layer %d: cache append: %w", i, err)
			}
			fullK, fullV, _, ok := cache.FullKV(i)
			if !ok {
				return nil, fmt.Errorf("layer %d: full kv missing", i)
			}
			attnCore := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaVHeadDim}, nil)
			if err := nn.CausalAttentionCachedKV(qCPU, fullK, fullV, attnCore, m.mlaNumHeads, m.mlaNumHeads, m.mlaKHeadDim, m.mlaVHeadDim, startPos); err != nil {
				return nil, fmt.Errorf("layer %d: mla attention: %w", i, err)
			}
			attnCoreAct := compute.NewActivationFrom(attnCore)
			if compCtx.IsGPU() {
				if _, err := attnCoreAct.EnsureOn(compCtx.Placement()); err != nil {
					return nil, err
				}
			}
			if err := compute.HybridLinear(compCtx, attnCoreAct, layer.mlaOProj, attnOutAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla o_proj: %w", i, err)
			}
		default:
			return nil, fmt.Errorf("layer %d: unknown layer kind", i)
		}

		if err := compute.HybridAdd(compCtx, attnOutAct, residualAct); err != nil {
			return nil, err
		}
		hidden = attnOutAct

		// Save MLP residual BEFORE post_attention_layernorm (pre-norm architecture)
		mlpResidualAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
		if err != nil {
			return nil, err
		}
		if err := compute.HybridCopy(compCtx, mlpResidualAct, hidden); err != nil {
			return nil, err
		}

		if layer.postAttnNorm == nil {
			return nil, fmt.Errorf("layer %d: missing post_attention_layernorm", i)
		}
		if err := compute.HybridRMSNorm(compCtx, hidden, layer.postAttnNorm, eps); err != nil {
			return nil, fmt.Errorf("layer %d: post attn norm: %w", i, err)
		}
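
		// FFN routing: use the MoE experts only when the expert weights are present
		// and the layer index is on the MoE schedule (the first FirstKDenseReplace
		// layers stay dense, then every LayerFreq-th layer after that is MoE).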
		moeCfg, _ := parseMoEConfig(m.config)
		useMoE := layer.moeGateW != nil && len(layer.moeW1) > 0 && len(layer.moeW2) > 0 && len(layer.moeW3) > 0
		if useMoE {
			first := moeCfg.FirstKDenseReplace
			if first <= 0 {
				first = 1
			}
			freq := moeCfg.LayerFreq
			if freq <= 0 {
				freq = 1
			}
			if i < first || (i-first)%freq != 0 {
				useMoE = false
			}
		}
		if useMoE {
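			// Routed MoE FFN: bundle the router and expert weights, dispatch through
			// the hybrid MoE kernel, then add the pre-norm residual back in.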
			moeOutAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
			if err != nil {
				return nil, err
			}
			moeWeights := &compute.MoEWeights{
				GateW:      layer.moeGateW,
				GateBias:   layer.moeGateBias,
				W1:         layer.moeW1,
				W2:         layer.moeW2,
				W3:         layer.moeW3,
				SharedGate: layer.moeSharedGate,
				SharedUp:   layer.moeSharedUp,
				SharedDown: layer.moeSharedDown,
			}
			moeCfgCompute := compute.MoEConfig{
				NumExperts:           moeCfg.NumExperts,
				TopK:                 moeCfg.TopK,
				IntermediateSize:     moeCfg.IntermediateSize,
				RouterActivationFunc: moeCfg.RouterActivationFunc,
				UseGroupedTopK:       moeCfg.UseGroupedTopK,
				NumExpertGroup:       moeCfg.NumExpertGroup,
				TopKGroup:            moeCfg.TopKGroup,
				Renormalize:          moeCfg.Renormalize,
				RoutedScalingFactor:  moeCfg.RoutedScalingFactor,
				NumSharedExperts:     moeCfg.NumSharedExperts,
			}
			if err := compute.HybridMoE(compCtx, hidden, moeWeights, moeCfgCompute, moeOutAct); err != nil {
				return nil, fmt.Errorf("layer %d: moe: %w", i, err)
			}
			if err := compute.HybridAdd(compCtx, moeOutAct, mlpResidualAct); err != nil {
				return nil, err
			}
			hidden = moeOutAct
		} else {
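			// Dense SwiGLU MLP: gate and up projections, SwiGLU activation, down
			// projection, then add the pre-norm residual back in.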
			if layer.mlpGateProj == nil || layer.mlpUpProj == nil || layer.mlpDownProj == nil {
				return nil, fmt.Errorf("layer %d: missing MLP weights", i)
			}
			gateAct, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
			if err != nil {
				return nil, err
			}
			upAct, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlpGateProj, gateAct); err != nil {
				return nil, fmt.Errorf("layer %d: mlp gate: %w", i, err)
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlpUpProj, upAct); err != nil {
				return nil, fmt.Errorf("layer %d: mlp up: %w", i, err)
			}
			act, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridSwiGLU(compCtx, gateAct, upAct, act); err != nil {
				return nil, fmt.Errorf("layer %d: swiglu: %w", i, err)
			}
			mlpOut, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, act, layer.mlpDownProj, mlpOut); err != nil {
				return nil, fmt.Errorf("layer %d: mlp down: %w", i, err)
			}
			if err := compute.HybridAdd(compCtx, mlpOut, mlpResidualAct); err != nil {
				return nil, err
			}
			hidden = mlpOut
		}
	}
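
	// All layers done: advance the cache by seqLen tokens, apply the final RMS
	// norm, and compute logits with the LM head (falling back to the tied token
	// embedding weights when no separate output projection is present).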
	cache.Commit(seqLen)

	finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
	if m.norm == nil {
		return nil, fmt.Errorf("missing model.norm.weight")
	}
	if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, eps); err != nil {
		return nil, fmt.Errorf("final norm: %w", err)
	}

	logits := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil))
	outputW := m.output
	if outputW == nil {
		outputW = m.tokenEmb
	}
	if outputW == nil {
		return nil, fmt.Errorf("missing lm_head.weight and embed_tokens.weight")
	}
	if err := compute.HybridLinear(finalCtx, hidden, outputW, logits); err != nil {
		return nil, fmt.Errorf("lm head: %w", err)
	}
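
	// posArr is parsed but not consumed in this path; the blank assignment keeps
	// the compiler from flagging it as unused.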
	_ = posArr

	return logits.Tensor(), nil
}