- package kimi_linear
- import (
- "context"
- "fmt"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cpu/nn"
- "makarna/pkg/compute"
- "makarna/pkg/model"
- "makarna/pkg/tensor"
- )
- func (m *Model) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache model.KVCache) (tensor.Tensor, error) {
- cfg := m.config
- seqLen := input.Shape()[0]
- ids := nn.ParseTokenIDs(input)
- posArr := nn.ParsePositions(positions, seqLen)
- dispatcher := compute.DispatcherFromContext(ctx)
- scratchSet := compute.ScratchSetFromContext(ctx)
- baseScratch := compute.ScratchFromContext(ctx)
- cpuMoE := compute.CPUMoEFromContext(ctx)
- // Track GPU allocations for cleanup at end of forward pass
- var gpuAllocations []*compute.Activation
- defer func() {
- for _, act := range gpuAllocations {
- compute.FreeActivation(act)
- }
- }()
- hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
- if err != nil {
- return nil, fmt.Errorf("embedding: %w", err)
- }
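- // Reuse the caller's KimiCache when one is provided; otherwise allocate fresh per-layer KDA conv/recurrent state and MLA KV storage.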
- var cache *KimiCache
- if kvCache != nil {
- if c, ok := AsKimiCache(kvCache); ok {
- cache = c
- }
- }
- if cache == nil {
- kdaCfg, _ := parseLinearAttnConfig(cfg)
- cache, err = NewKimiCache(cfg.NumLayers, m.kdaNumHeads, m.kdaHeadDim, kdaCfg.ShortConvKernel, m.mlaNumHeads, m.mlaKHeadDim, m.mlaVHeadDim)
- if err != nil {
- return nil, err
- }
- }
- eps := float32(cfg.RMSNormEps)
- if eps == 0 {
- eps = 1e-5
- }
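- // Walk the layers: each applies pre-norm attention (KDA or full MLA) with a residual, then a pre-norm MoE or dense MLP block with a second residual.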
- for i, layer := range m.layers {
- compCtx := compute.NewContext(dispatcher, i)
- compCtx.CPUMoE = cpuMoE
- if p := compCtx.Placement(); p.Type == tensor.CUDA {
- var layerScratch *compute.ScratchSpace
- if scratchSet != nil {
- layerScratch = scratchSet.Scratch(p.GPU)
- } else if baseScratch != nil && baseScratch.GPU() == p.GPU {
- layerScratch = baseScratch
- }
- if layerScratch != nil {
- layerScratch.Reset()
- }
- compCtx.Scratch = layerScratch
- }
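- // allocAct serves layer activations from the per-GPU scratch arena when available; non-scratch CUDA allocations are tracked for the deferred cleanup above.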
- allocAct := func(shape tensor.Shape) (*compute.Activation, error) {
- if compCtx.Scratch != nil && compCtx.Placement().Type == tensor.CUDA {
- if act, err := compCtx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
- return act, nil
- }
- }
- act, err := compute.NewActivation(shape, compCtx.Placement())
- if err == nil && act != nil && compCtx.Placement().Type == tensor.CUDA {
- // Track GPU allocations that are not from scratch
- gpuAllocations = append(gpuAllocations, act)
- }
- return act, err
- }
- // Ensure activations are on the target device for this layer.
- if _, err := hidden.EnsureOn(compCtx.Placement()); err != nil {
- return nil, err
- }
- // Save residual BEFORE layernorm (pre-norm architecture)
- residualAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridCopy(compCtx, residualAct, hidden); err != nil {
- return nil, err
- }
- if layer.inputLayernorm == nil {
- return nil, fmt.Errorf("layer %d: missing input_layernorm", i)
- }
- if err := compute.HybridRMSNorm(compCtx, hidden, layer.inputLayernorm, eps); err != nil {
- return nil, fmt.Errorf("layer %d: attn norm: %w", i, err)
- }
- attnOutAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
- if err != nil {
- return nil, err
- }
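- // Attention dispatch: KDA layers run the linear-attention path over short-conv and recurrent cache state; full layers run MLA attention against the growing KV cache.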
- switch layer.kind {
- case layerKindKDA:
- if layer.kdaQProj == nil || layer.kdaKProj == nil || layer.kdaVProj == nil || layer.kdaOProj == nil {
- return nil, fmt.Errorf("layer %d: missing KDA projections", i)
- }
- if layer.kdaQConv == nil || layer.kdaKConv == nil || layer.kdaVConv == nil {
- return nil, fmt.Errorf("layer %d: missing KDA conv weights", i)
- }
- if layer.kdaFAProj == nil || layer.kdaFBProj == nil || layer.kdaBProj == nil {
- return nil, fmt.Errorf("layer %d: missing KDA gate/beta projections", i)
- }
- if layer.kdaALog == nil || layer.kdaDTBias == nil {
- return nil, fmt.Errorf("layer %d: missing KDA A_log/dt_bias", i)
- }
- if layer.kdaGAProj == nil || layer.kdaGBProj == nil {
- return nil, fmt.Errorf("layer %d: missing KDA o_norm gate projections", i)
- }
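- // Fetch this layer's conv and recurrent states from the cache on the layer's device, then run the fused KDA kernel into attnOutAct.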
- kdaCfg, _ := parseLinearAttnConfig(m.config)
- convQ, convK, convV, err := cache.ConvStates(i, compCtx.Placement())
- if err != nil {
- return nil, err
- }
- stT, err := cache.RecurrentState(i, compCtx.Placement())
- if err != nil {
- return nil, err
- }
- if err := compute.HybridKDA(
- compCtx,
- hidden,
- layer.kdaQProj, layer.kdaKProj, layer.kdaVProj,
- layer.kdaQConv, layer.kdaKConv, layer.kdaVConv,
- layer.kdaFAProj, layer.kdaFBProj, layer.kdaBProj,
- layer.kdaALog,
- layer.kdaDTBias,
- layer.kdaGAProj, layer.kdaGBProj,
- layer.kdaONorm,
- layer.kdaOProj,
- convQ, convK, convV,
- stT,
- seqLen, m.kdaNumHeads, m.kdaHeadDim, kdaCfg.ShortConvKernel,
- eps,
- attnOutAct,
- ); err != nil {
- return nil, fmt.Errorf("layer %d: kda: %w", i, err)
- }
- case layerKindFull:
- if layer.mlaQProj == nil || layer.mlaKVAProjWithMQA == nil || layer.mlaKVALayernorm == nil || layer.mlaKVBProj == nil || layer.mlaOProj == nil {
- return nil, fmt.Errorf("layer %d: missing MLA weights", i)
- }
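- // MLA dimensions: per-head K is qkNope+qkRope wide, and kv_a projects to a kvARank latent plus a qkRope slice shared across heads.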
- mlaCfg, _ := parseMLAConfig(m.config)
- qkNope := mlaCfg.QKNopeHeadDim
- qkRope := mlaCfg.QKRopeHeadDim
- vDim := mlaCfg.VHeadDim
- kvARank := mlaCfg.KVLoraRank
- if kvARank <= 0 || qkNope <= 0 || qkRope <= 0 || vDim <= 0 {
- return nil, fmt.Errorf("layer %d: invalid MLA config", i)
- }
- if qkNope+qkRope != m.mlaKHeadDim {
- return nil, fmt.Errorf("layer %d: mla head dim mismatch", i)
- }
- if vDim != m.mlaVHeadDim {
- return nil, fmt.Errorf("layer %d: mla v dim mismatch", i)
- }
- qAct, err := compute.NewActivation(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaKHeadDim}, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridLinear(compCtx, hidden, layer.mlaQProj, qAct); err != nil {
- return nil, fmt.Errorf("layer %d: mla q_proj: %w", i, err)
- }
- qCPU, _ := qAct.AsCPU()
- kvAAct, err := compute.NewActivation(tensor.Shape{seqLen, kvARank + qkRope}, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridLinear(compCtx, hidden, layer.mlaKVAProjWithMQA, kvAAct); err != nil {
- return nil, fmt.Errorf("layer %d: mla kv_a_proj: %w", i, err)
- }
- kvACPU, _ := kvAAct.AsCPU()
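- // The first kvARank columns of kv_a are the compressed KV latent; slice them out, RMS-normalize, then expand with kv_b_proj.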
- kvPass := cpu.NewTensor(tensor.Shape{seqLen, kvARank}, nil)
- for t := 0; t < seqLen; t++ {
- rowBase := t * (kvARank + qkRope)
- copy(kvPass.DataFloat32()[t*kvARank:(t+1)*kvARank], kvACPU.DataFloat32()[rowBase:rowBase+kvARank])
- }
- kvPassAct := compute.NewActivationFrom(kvPass)
- if compCtx.IsGPU() {
- if _, err := kvPassAct.EnsureOn(compCtx.Placement()); err != nil {
- return nil, err
- }
- }
- if err := compute.HybridRMSNorm(compCtx, kvPassAct, layer.mlaKVALayernorm, eps); err != nil {
- return nil, fmt.Errorf("layer %d: mla kv_a_layernorm: %w", i, err)
- }
- kvBAct, err := allocAct(tensor.Shape{seqLen, m.mlaNumHeads * (qkNope + vDim)})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridLinear(compCtx, kvPassAct, layer.mlaKVBProj, kvBAct); err != nil {
- return nil, fmt.Errorf("layer %d: mla kv_b_proj: %w", i, err)
- }
- kvBCPU, _ := kvBAct.AsCPU()
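- // Assemble per-head K (nope slice from kv_b, shared qkRope slice from kv_a) and per-head V for this step.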
- kStep := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaKHeadDim}, nil)
- vStep := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaVHeadDim}, nil)
- for t := 0; t < seqLen; t++ {
- kRotBase := t*(kvARank+qkRope) + kvARank
- kRot := kvACPU.DataFloat32()[kRotBase : kRotBase+qkRope]
- for h := 0; h < m.mlaNumHeads; h++ {
- srcBase := t*(m.mlaNumHeads*(qkNope+vDim)) + h*(qkNope+vDim)
- kDst := t*(m.mlaNumHeads*m.mlaKHeadDim) + h*m.mlaKHeadDim
- vDst := t*(m.mlaNumHeads*m.mlaVHeadDim) + h*m.mlaVHeadDim
- copy(kStep.DataFloat32()[kDst:kDst+qkNope], kvBCPU.DataFloat32()[srcBase:srcBase+qkNope])
- copy(kStep.DataFloat32()[kDst+qkNope:kDst+qkNope+qkRope], kRot)
- copy(vStep.DataFloat32()[vDst:vDst+vDim], kvBCPU.DataFloat32()[srcBase+qkNope:srcBase+qkNope+vDim])
- }
- }
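- // Append this step's K/V to the layer cache and attend causally over everything cached so far.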
- startPos, err := cache.AppendFull(i, kStep, vStep)
- if err != nil {
- return nil, fmt.Errorf("layer %d: cache append: %w", i, err)
- }
- fullK, fullV, _, ok := cache.FullKV(i)
- if !ok {
- return nil, fmt.Errorf("layer %d: full kv missing", i)
- }
- attnCore := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaVHeadDim}, nil)
- if err := nn.CausalAttentionCachedKV(qCPU, fullK, fullV, attnCore, m.mlaNumHeads, m.mlaNumHeads, m.mlaKHeadDim, m.mlaVHeadDim, startPos); err != nil {
- return nil, fmt.Errorf("layer %d: mla attention: %w", i, err)
- }
- attnCoreAct := compute.NewActivationFrom(attnCore)
- if compCtx.IsGPU() {
- if _, err := attnCoreAct.EnsureOn(compCtx.Placement()); err != nil {
- return nil, err
- }
- }
- if err := compute.HybridLinear(compCtx, attnCoreAct, layer.mlaOProj, attnOutAct); err != nil {
- return nil, fmt.Errorf("layer %d: mla o_proj: %w", i, err)
- }
- default:
- return nil, fmt.Errorf("layer %d: unknown layer kind", i)
- }
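- // First residual connection: add the pre-norm input back onto the attention output.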
- if err := compute.HybridAdd(compCtx, attnOutAct, residualAct); err != nil {
- return nil, err
- }
- hidden = attnOutAct
- // Save MLP residual BEFORE post_attention_layernorm (pre-norm architecture)
- mlpResidualAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridCopy(compCtx, mlpResidualAct, hidden); err != nil {
- return nil, err
- }
- if layer.postAttnNorm == nil {
- return nil, fmt.Errorf("layer %d: missing post_attention_layernorm", i)
- }
- if err := compute.HybridRMSNorm(compCtx, hidden, layer.postAttnNorm, eps); err != nil {
- return nil, fmt.Errorf("layer %d: post attn norm: %w", i, err)
- }
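- // Decide whether this layer uses MoE: expert weights must be present and the index must fall on the MoE schedule (from FirstKDenseReplace onward, every LayerFreq layers).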
- moeCfg, _ := parseMoEConfig(m.config)
- useMoE := layer.moeGateW != nil && len(layer.moeW1) > 0 && len(layer.moeW2) > 0 && len(layer.moeW3) > 0
- if useMoE {
- first := moeCfg.FirstKDenseReplace
- if first <= 0 {
- first = 1
- }
- freq := moeCfg.LayerFreq
- if freq <= 0 {
- freq = 1
- }
- if i < first || (i-first)%freq != 0 {
- useMoE = false
- }
- }
- if useMoE {
- moeOutAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
- if err != nil {
- return nil, err
- }
- moeWeights := &compute.MoEWeights{
- GateW: layer.moeGateW,
- GateBias: layer.moeGateBias,
- W1: layer.moeW1,
- W2: layer.moeW2,
- W3: layer.moeW3,
- SharedGate: layer.moeSharedGate,
- SharedUp: layer.moeSharedUp,
- SharedDown: layer.moeSharedDown,
- }
- moeCfgCompute := compute.MoEConfig{
- NumExperts: moeCfg.NumExperts,
- TopK: moeCfg.TopK,
- IntermediateSize: moeCfg.IntermediateSize,
- RouterActivationFunc: moeCfg.RouterActivationFunc,
- UseGroupedTopK: moeCfg.UseGroupedTopK,
- NumExpertGroup: moeCfg.NumExpertGroup,
- TopKGroup: moeCfg.TopKGroup,
- Renormalize: moeCfg.Renormalize,
- RoutedScalingFactor: moeCfg.RoutedScalingFactor,
- NumSharedExperts: moeCfg.NumSharedExperts,
- }
- if err := compute.HybridMoE(compCtx, hidden, moeWeights, moeCfgCompute, moeOutAct); err != nil {
- return nil, fmt.Errorf("layer %d: moe: %w", i, err)
- }
- if err := compute.HybridAdd(compCtx, moeOutAct, mlpResidualAct); err != nil {
- return nil, err
- }
- hidden = moeOutAct
- } else {
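- // Dense MLP path: gate and up projections, SwiGLU activation, then the down projection.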
- if layer.mlpGateProj == nil || layer.mlpUpProj == nil || layer.mlpDownProj == nil {
- return nil, fmt.Errorf("layer %d: missing MLP weights", i)
- }
- gateAct, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
- if err != nil {
- return nil, err
- }
- upAct, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridLinear(compCtx, hidden, layer.mlpGateProj, gateAct); err != nil {
- return nil, fmt.Errorf("layer %d: mlp gate: %w", i, err)
- }
- if err := compute.HybridLinear(compCtx, hidden, layer.mlpUpProj, upAct); err != nil {
- return nil, fmt.Errorf("layer %d: mlp up: %w", i, err)
- }
- act, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridSwiGLU(compCtx, gateAct, upAct, act); err != nil {
- return nil, fmt.Errorf("layer %d: swiglu: %w", i, err)
- }
- mlpOut, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
- if err != nil {
- return nil, err
- }
- if err := compute.HybridLinear(compCtx, act, layer.mlpDownProj, mlpOut); err != nil {
- return nil, fmt.Errorf("layer %d: mlp down: %w", i, err)
- }
- if err := compute.HybridAdd(compCtx, mlpOut, mlpResidualAct); err != nil {
- return nil, err
- }
- hidden = mlpOut
- }
- }
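- // All layers have appended this step's states; advance the cache position.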
- cache.Commit(seqLen)
- finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
- if m.norm == nil {
- return nil, fmt.Errorf("missing model.norm.weight")
- }
- if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, eps); err != nil {
- return nil, fmt.Errorf("final norm: %w", err)
- }
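- // Project the normalized hidden states to vocabulary logits, falling back to tied token embeddings when no lm_head weight is present.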
- logits := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil))
- outputW := m.output
- if outputW == nil {
- outputW = m.tokenEmb
- }
- if outputW == nil {
- return nil, fmt.Errorf("missing lm_head.weight and embed_tokens.weight")
- }
- if err := compute.HybridLinear(finalCtx, hidden, outputW, logits); err != nil {
- return nil, fmt.Errorf("lm head: %w", err)
- }
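- // posArr is not consumed on this path; the blank assignment below keeps the compiler's unused-variable check satisfied.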
- _ = posArr
- return logits.Tensor(), nil
- }