package kimi_linear

import (
	"context"
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/compute"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)

func (m *Model) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache model.KVCache) (tensor.Tensor, error) {
	cfg := m.config
	seqLen := input.Shape()[0]
	ids := nn.ParseTokenIDs(input)
	posArr := nn.ParsePositions(positions, seqLen)

	dispatcher := compute.DispatcherFromContext(ctx)
	scratchSet := compute.ScratchSetFromContext(ctx)
	baseScratch := compute.ScratchFromContext(ctx)
	cpuMoE := compute.CPUMoEFromContext(ctx)

	// Track GPU allocations for cleanup at end of forward pass
	var gpuAllocations []*compute.Activation
	defer func() {
		for _, act := range gpuAllocations {
			compute.FreeActivation(act)
		}
	}()

	hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
	if err != nil {
		return nil, fmt.Errorf("embedding: %w", err)
	}

	var cache *KimiCache
	if kvCache != nil {
		if c, ok := AsKimiCache(kvCache); ok {
			cache = c
		}
	}
	if cache == nil {
		kdaCfg, _ := parseLinearAttnConfig(cfg)
		cache, err = NewKimiCache(cfg.NumLayers, m.kdaNumHeads, m.kdaHeadDim, kdaCfg.ShortConvKernel, m.mlaNumHeads, m.mlaKHeadDim, m.mlaVHeadDim)
		if err != nil {
			return nil, err
		}
	}

	eps := float32(cfg.RMSNormEps)
	if eps == 0 {
		eps = 1e-5
	}

	for i, layer := range m.layers {
		compCtx := compute.NewContext(dispatcher, i)
		compCtx.CPUMoE = cpuMoE
		if p := compCtx.Placement(); p.Type == tensor.CUDA {
			var layerScratch *compute.ScratchSpace
			if scratchSet != nil {
				layerScratch = scratchSet.Scratch(p.GPU)
			} else if baseScratch != nil && baseScratch.GPU() == p.GPU {
				layerScratch = baseScratch
			}
			if layerScratch != nil {
				layerScratch.Reset()
			}
			compCtx.Scratch = layerScratch
		}

		allocAct := func(shape tensor.Shape) (*compute.Activation, error) {
			if compCtx.Scratch != nil && compCtx.Placement().Type == tensor.CUDA {
				if act, err := compCtx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
					return act, nil
				}
			}
			act, err := compute.NewActivation(shape, compCtx.Placement())
			if err == nil && act != nil && compCtx.Placement().Type == tensor.CUDA {
				// Track GPU allocations that are not from scratch
				gpuAllocations = append(gpuAllocations, act)
			}
			return act, err
		}
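		// Per-layer flow (pre-norm): save the residual, RMS-normalize, run the
		// token mixer (KDA linear attention or full MLA attention depending on
		// layer.kind), add the residual back, then repeat the same norm/residual
		// pattern around the MLP block (dense SwiGLU or routed MoE).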
		// Ensure activations are on the target device for this layer.
		if _, err := hidden.EnsureOn(compCtx.Placement()); err != nil {
			return nil, err
		}

		// Save residual BEFORE layernorm (pre-norm architecture)
		residualAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
		if err != nil {
			return nil, err
		}
		if err := compute.HybridCopy(compCtx, residualAct, hidden); err != nil {
			return nil, err
		}

		if layer.inputLayernorm == nil {
			return nil, fmt.Errorf("layer %d: missing input_layernorm", i)
		}
		if err := compute.HybridRMSNorm(compCtx, hidden, layer.inputLayernorm, eps); err != nil {
			return nil, fmt.Errorf("layer %d: attn norm: %w", i, err)
		}

		attnOutAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
		if err != nil {
			return nil, err
		}

		switch layer.kind {
		case layerKindKDA:
			if layer.kdaQProj == nil || layer.kdaKProj == nil || layer.kdaVProj == nil || layer.kdaOProj == nil {
				return nil, fmt.Errorf("layer %d: missing KDA projections", i)
			}
			if layer.kdaQConv == nil || layer.kdaKConv == nil || layer.kdaVConv == nil {
				return nil, fmt.Errorf("layer %d: missing KDA conv weights", i)
			}
			if layer.kdaFAProj == nil || layer.kdaFBProj == nil || layer.kdaBProj == nil {
				return nil, fmt.Errorf("layer %d: missing KDA gate/beta projections", i)
			}
			if layer.kdaALog == nil || layer.kdaDTBias == nil {
				return nil, fmt.Errorf("layer %d: missing KDA A_log/dt_bias", i)
			}
			if layer.kdaGAProj == nil || layer.kdaGBProj == nil {
				return nil, fmt.Errorf("layer %d: missing KDA o_norm gate projections", i)
			}
			kdaCfg, _ := parseLinearAttnConfig(m.config)
			convQ, convK, convV, err := cache.ConvStates(i, compCtx.Placement())
			if err != nil {
				return nil, err
			}
			stT, err := cache.RecurrentState(i, compCtx.Placement())
			if err != nil {
				return nil, err
			}
			if err := compute.HybridKDA(
				compCtx, hidden,
				layer.kdaQProj, layer.kdaKProj, layer.kdaVProj,
				layer.kdaQConv, layer.kdaKConv, layer.kdaVConv,
				layer.kdaFAProj, layer.kdaFBProj, layer.kdaBProj,
				layer.kdaALog, layer.kdaDTBias,
				layer.kdaGAProj, layer.kdaGBProj,
				layer.kdaONorm, layer.kdaOProj,
				convQ, convK, convV, stT,
				seqLen, m.kdaNumHeads, m.kdaHeadDim, kdaCfg.ShortConvKernel, eps,
				attnOutAct,
			); err != nil {
				return nil, fmt.Errorf("layer %d: kda: %w", i, err)
			}
		case layerKindFull:
			if layer.mlaQProj == nil || layer.mlaKVAProjWithMQA == nil || layer.mlaKVALayernorm == nil || layer.mlaKVBProj == nil || layer.mlaOProj == nil {
				return nil, fmt.Errorf("layer %d: missing MLA weights", i)
			}
			mlaCfg, _ := parseMLAConfig(m.config)
			qkNope := mlaCfg.QKNopeHeadDim
			qkRope := mlaCfg.QKRopeHeadDim
			vDim := mlaCfg.VHeadDim
			kvARank := mlaCfg.KVLoraRank
			if kvARank <= 0 || qkNope <= 0 || qkRope <= 0 || vDim <= 0 {
				return nil, fmt.Errorf("layer %d: invalid MLA config", i)
			}
			if qkNope+qkRope != m.mlaKHeadDim {
				return nil, fmt.Errorf("layer %d: mla head dim mismatch", i)
			}
			if vDim != m.mlaVHeadDim {
				return nil, fmt.Errorf("layer %d: mla v dim mismatch", i)
			}

			qAct, err := compute.NewActivation(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaKHeadDim}, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlaQProj, qAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla q_proj: %w", i, err)
			}
			qCPU, _ := qAct.AsCPU()

			kvAAct, err := compute.NewActivation(tensor.Shape{seqLen, kvARank + qkRope}, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlaKVAProjWithMQA, kvAAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla kv_a_proj: %w", i, err)
			}
			kvACPU, _ := kvAAct.AsCPU()
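			// kv_a_proj_with_mqa packs, per token, the compressed KV latent
			// (kvARank values, normalized and expanded by kv_b_proj below)
			// followed by a shared rotary key slice (qkRope values) that is
			// appended to every head's key.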
			kvPass := cpu.NewTensor(tensor.Shape{seqLen, kvARank}, nil)
			for t := 0; t < seqLen; t++ {
				src := t * (kvARank + qkRope)
				copy(kvPass.DataFloat32()[t*kvARank:(t+1)*kvARank], kvACPU.DataFloat32()[src:src+kvARank])
			}
			kvPassAct := compute.NewActivationFrom(kvPass)
			if compCtx.IsGPU() {
				if _, err := kvPassAct.EnsureOn(compCtx.Placement()); err != nil {
					return nil, err
				}
			}
			if err := compute.HybridRMSNorm(compCtx, kvPassAct, layer.mlaKVALayernorm, eps); err != nil {
				return nil, fmt.Errorf("layer %d: mla kv_a_layernorm: %w", i, err)
			}

			kvBAct, err := allocAct(tensor.Shape{seqLen, m.mlaNumHeads * (qkNope + vDim)})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, kvPassAct, layer.mlaKVBProj, kvBAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla kv_b_proj: %w", i, err)
			}
			kvBCPU, _ := kvBAct.AsCPU()

			kStep := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaKHeadDim}, nil)
			vStep := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaVHeadDim}, nil)
			for t := 0; t < seqLen; t++ {
				kRotBase := t*(kvARank+qkRope) + kvARank
				kRot := kvACPU.DataFloat32()[kRotBase : kRotBase+qkRope]
				for h := 0; h < m.mlaNumHeads; h++ {
					srcBase := t*(m.mlaNumHeads*(qkNope+vDim)) + h*(qkNope+vDim)
					kBase := t*(m.mlaNumHeads*m.mlaKHeadDim) + h*m.mlaKHeadDim
					vBase := t*(m.mlaNumHeads*m.mlaVHeadDim) + h*m.mlaVHeadDim
					copy(kStep.DataFloat32()[kBase:kBase+qkNope], kvBCPU.DataFloat32()[srcBase:srcBase+qkNope])
					copy(kStep.DataFloat32()[kBase+qkNope:kBase+qkNope+qkRope], kRot)
					copy(vStep.DataFloat32()[vBase:vBase+vDim], kvBCPU.DataFloat32()[srcBase+qkNope:srcBase+qkNope+vDim])
				}
			}

			startPos, err := cache.AppendFull(i, kStep, vStep)
			if err != nil {
				return nil, fmt.Errorf("layer %d: cache append: %w", i, err)
			}
			fullK, fullV, _, ok := cache.FullKV(i)
			if !ok {
				return nil, fmt.Errorf("layer %d: full kv missing", i)
			}

			attnCore := cpu.NewTensor(tensor.Shape{seqLen, m.mlaNumHeads * m.mlaVHeadDim}, nil)
			if err := nn.CausalAttentionCachedKV(qCPU, fullK, fullV, attnCore, m.mlaNumHeads, m.mlaNumHeads, m.mlaKHeadDim, m.mlaVHeadDim, startPos); err != nil {
				return nil, fmt.Errorf("layer %d: mla attention: %w", i, err)
			}
			attnCoreAct := compute.NewActivationFrom(attnCore)
			if compCtx.IsGPU() {
				if _, err := attnCoreAct.EnsureOn(compCtx.Placement()); err != nil {
					return nil, err
				}
			}
			if err := compute.HybridLinear(compCtx, attnCoreAct, layer.mlaOProj, attnOutAct); err != nil {
				return nil, fmt.Errorf("layer %d: mla o_proj: %w", i, err)
			}
		default:
			return nil, fmt.Errorf("layer %d: unknown layer kind", i)
		}

		if err := compute.HybridAdd(compCtx, attnOutAct, residualAct); err != nil {
			return nil, err
		}
		hidden = attnOutAct

		// Save MLP residual BEFORE post_attention_layernorm (pre-norm architecture)
		mlpResidualAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
		if err != nil {
			return nil, err
		}
		if err := compute.HybridCopy(compCtx, mlpResidualAct, hidden); err != nil {
			return nil, err
		}

		if layer.postAttnNorm == nil {
			return nil, fmt.Errorf("layer %d: missing post_attention_layernorm", i)
		}
		if err := compute.HybridRMSNorm(compCtx, hidden, layer.postAttnNorm, eps); err != nil {
			return nil, fmt.Errorf("layer %d: post attn norm: %w", i, err)
		}

		moeCfg, _ := parseMoEConfig(m.config)
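		// MoE placement: the first FirstKDenseReplace layers always use the dense
		// MLP; after that, only every LayerFreq-th layer routes through the experts.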
		useMoE := layer.moeGateW != nil && len(layer.moeW1) > 0 && len(layer.moeW2) > 0 && len(layer.moeW3) > 0
		if useMoE {
			first := moeCfg.FirstKDenseReplace
			if first <= 0 {
				first = 1
			}
			freq := moeCfg.LayerFreq
			if freq <= 0 {
				freq = 1
			}
			if i < first || (i-first)%freq != 0 {
				useMoE = false
			}
		}

		if useMoE {
			moeOutAct, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
			if err != nil {
				return nil, err
			}
			moeWeights := &compute.MoEWeights{
				GateW:      layer.moeGateW,
				GateBias:   layer.moeGateBias,
				W1:         layer.moeW1,
				W2:         layer.moeW2,
				W3:         layer.moeW3,
				SharedGate: layer.moeSharedGate,
				SharedUp:   layer.moeSharedUp,
				SharedDown: layer.moeSharedDown,
			}
			moeCfgCompute := compute.MoEConfig{
				NumExperts:           moeCfg.NumExperts,
				TopK:                 moeCfg.TopK,
				IntermediateSize:     moeCfg.IntermediateSize,
				RouterActivationFunc: moeCfg.RouterActivationFunc,
				UseGroupedTopK:       moeCfg.UseGroupedTopK,
				NumExpertGroup:       moeCfg.NumExpertGroup,
				TopKGroup:            moeCfg.TopKGroup,
				Renormalize:          moeCfg.Renormalize,
				RoutedScalingFactor:  moeCfg.RoutedScalingFactor,
				NumSharedExperts:     moeCfg.NumSharedExperts,
			}
			if err := compute.HybridMoE(compCtx, hidden, moeWeights, moeCfgCompute, moeOutAct); err != nil {
				return nil, fmt.Errorf("layer %d: moe: %w", i, err)
			}
			if err := compute.HybridAdd(compCtx, moeOutAct, mlpResidualAct); err != nil {
				return nil, err
			}
			hidden = moeOutAct
		} else {
			if layer.mlpGateProj == nil || layer.mlpUpProj == nil || layer.mlpDownProj == nil {
				return nil, fmt.Errorf("layer %d: missing MLP weights", i)
			}
			gateAct, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
			if err != nil {
				return nil, err
			}
			upAct, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlpGateProj, gateAct); err != nil {
				return nil, fmt.Errorf("layer %d: mlp gate: %w", i, err)
			}
			if err := compute.HybridLinear(compCtx, hidden, layer.mlpUpProj, upAct); err != nil {
				return nil, fmt.Errorf("layer %d: mlp up: %w", i, err)
			}
			act, err := allocAct(tensor.Shape{seqLen, cfg.Intermediate})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridSwiGLU(compCtx, gateAct, upAct, act); err != nil {
				return nil, fmt.Errorf("layer %d: swiglu: %w", i, err)
			}
			mlpOut, err := allocAct(tensor.Shape{seqLen, cfg.HiddenSize})
			if err != nil {
				return nil, err
			}
			if err := compute.HybridLinear(compCtx, act, layer.mlpDownProj, mlpOut); err != nil {
				return nil, fmt.Errorf("layer %d: mlp down: %w", i, err)
			}
			if err := compute.HybridAdd(compCtx, mlpOut, mlpResidualAct); err != nil {
				return nil, err
			}
			hidden = mlpOut
		}
	}

	cache.Commit(seqLen)

	finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
	if m.norm == nil {
		return nil, fmt.Errorf("missing model.norm.weight")
	}
	if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, eps); err != nil {
		return nil, fmt.Errorf("final norm: %w", err)
	}

	logits := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil))
	outputW := m.output
	if outputW == nil {
		outputW = m.tokenEmb
	}
	if outputW == nil {
		return nil, fmt.Errorf("missing lm_head.weight and embed_tokens.weight")
	}
	if err := compute.HybridLinear(finalCtx, hidden, outputW, logits); err != nil {
		return nil, fmt.Errorf("lm head: %w", err)
	}

	_ = posArr
	return logits.Tensor(), nil
}
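
// The sketch below is purely illustrative and is not called by Forward: it
// restates, on plain []float32 slices, how the MLA branch above de-interleaves
// one token's kv_b_proj output into per-head K (no-PE part followed by the
// shared rotary part) and V buffers. The function name and parameters are
// illustrative assumptions, not part of this package's API.
func splitMLAKVSketch(kvB, kRot []float32, numHeads, qkNope, qkRope, vDim int) (k, v []float32) {
	kHeadDim := qkNope + qkRope
	k = make([]float32, numHeads*kHeadDim)
	v = make([]float32, numHeads*vDim)
	for h := 0; h < numHeads; h++ {
		src := h * (qkNope + vDim) // kv_b_proj emits [nope | value] per head
		copy(k[h*kHeadDim:h*kHeadDim+qkNope], kvB[src:src+qkNope])
		// The rotary slice comes from kv_a_proj and is shared by every head.
		copy(k[h*kHeadDim+qkNope:(h+1)*kHeadDim], kRot)
		copy(v[h*vDim:(h+1)*vDim], kvB[src+qkNope:src+qkNope+vDim])
	}
	return k, v
}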