package compute

import (
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/tensor"
)

// HybridKDA runs the Kimi Delta Attention (KDA) block and writes its output to attnOut.
//
// Current behavior:
//   - If the current layer is placed on GPU (ctx.IsGPU() && hidden.IsGPU()), the CUDA path in
//     hybridKDAGPU is used; it returns an error when CUDA is not available.
//   - Otherwise the CPU path in hybridKDACPU is used.
func HybridKDA(
	ctx *Context,
	hidden *Activation,
	qProj, kProj, vProj tensor.Tensor,
	qConv, kConv, vConv tensor.Tensor,
	fAProj, fBProj, bProj tensor.Tensor,
	aLog tensor.Tensor,
	dtBias tensor.Tensor,
	gAProj, gBProj tensor.Tensor,
	oNorm tensor.Tensor,
	oProj tensor.Tensor,
	convQState, convKState, convVState tensor.Tensor,
	recurrentState tensor.Tensor,
	seqLen, numHeads, headDim, shortConvKernel int,
	eps float32,
	attnOut *Activation,
) error {
	if ctx != nil && ctx.IsGPU() && hidden.IsGPU() {
		return hybridKDAGPU(
			ctx, hidden,
			qProj, kProj, vProj,
			qConv, kConv, vConv,
			fAProj, fBProj, bProj,
			aLog, dtBias,
			gAProj, gBProj,
			oNorm, oProj,
			convQState, convKState, convVState,
			recurrentState,
			seqLen, numHeads, headDim, shortConvKernel,
			eps,
			attnOut,
		)
	}
	return hybridKDACPU(
		hidden,
		qProj, kProj, vProj,
		qConv, kConv, vConv,
		fAProj, fBProj, bProj,
		aLog, dtBias,
		gAProj, gBProj,
		oNorm, oProj,
		convQState, convKState, convVState,
		recurrentState,
		seqLen, numHeads, headDim, shortConvKernel,
		eps,
		attnOut,
	)
}

func hybridKDAGPU(
	ctx *Context,
	hidden *Activation,
	qProj, kProj, vProj tensor.Tensor,
	qConv, kConv, vConv tensor.Tensor,
	fAProj, fBProj, bProj tensor.Tensor,
	aLog tensor.Tensor,
	dtBias tensor.Tensor,
	gAProj, gBProj tensor.Tensor,
	oNorm tensor.Tensor,
	oProj tensor.Tensor,
	convQState, convKState, convVState tensor.Tensor,
	recurrentState tensor.Tensor,
	seqLen, numHeads, headDim, shortConvKernel int,
	eps float32,
	attnOut *Activation,
) error {
	if ctx == nil || !ctx.IsGPU() {
		return fmt.Errorf("HybridKDA/GPU: missing GPU context")
	}
	if !device.CUDAAvailable() || !cuda.Available() {
		return fmt.Errorf("HybridKDA/GPU: CUDA not available")
	}

	gpu := ctx.Placement().GPU
	projSize := numHeads * headDim

	// Prefer scratch-buffer reuse; fall back to a fresh CUDA allocation.
	alloc := func(shape tensor.Shape) (*Activation, error) {
		if ctx.Scratch != nil {
			if act, err := ctx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
				return act, nil
			}
		}
		return NewActivation(shape, tensor.DevicePlacement{Type: tensor.CUDA, GPU: gpu})
	}

	// Project to Q/K/V on GPU.
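	// Shape sketch for this step (the hidden layout is assumed to be row-major
	// [seqLen, hiddenDim]; the output shapes follow from the allocations below):
	//
	//	hidden [seqLen, hiddenDim] --qProj--> qAct [seqLen, numHeads*headDim]
	//	hidden [seqLen, hiddenDim] --kProj--> kAct [seqLen, numHeads*headDim]
	//	hidden [seqLen, hiddenDim] --vProj--> vAct [seqLen, numHeads*headDim]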
	qAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	kAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	vAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, qProj, qAct); err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, kProj, kAct); err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, vProj, vAct); err != nil {
		return err
	}
	qCUDA, _ := qAct.AsCUDA(gpu)
	kCUDA, _ := kAct.AsCUDA(gpu)
	vCUDA, _ := vAct.AsCUDA(gpu)

	qStateCUDA, ok := convQState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: convQState not cuda tensor")
	}
	kStateCUDA, ok := convKState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: convKState not cuda tensor")
	}
	vStateCUDA, ok := convVState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: convVState not cuda tensor")
	}

	// uploadF32 resolves a weight tensor to a device pointer: CUDA tensors are used in
	// place (after GPU/dtype checks), CPU tensors are uploaded once and cached under a
	// per-layer key.
	cache := GetWeightCache(gpu)
	uploadF32 := func(label string, w tensor.Tensor) (unsafe.Pointer, error) {
		if w == nil {
			return nil, fmt.Errorf("HybridKDA/GPU: missing %s", label)
		}
		if wt, ok := w.(*cuda.Tensor); ok {
			if wt.GPU() != gpu {
				return nil, fmt.Errorf("HybridKDA/GPU: %s on gpu=%d (want %d)", label, wt.GPU(), gpu)
			}
			if wt.DType() != tensor.Float32 {
				return nil, fmt.Errorf("HybridKDA/GPU: %s dtype=%v (want Float32)", label, wt.DType())
			}
			return wt.Data().(unsafe.Pointer), nil
		}
		wCPU, ok := w.(*cpu.Tensor)
		if !ok {
			return nil, fmt.Errorf("HybridKDA/GPU: %s not cpu/cuda tensor (%T)", label, w)
		}
		key := fmt.Sprintf("kda_l%d_%s", ctx.LayerIdx, label)
		if ptr, ok := cache.Get(key); ok {
			return ptr, nil
		}
		ptr, err := cache.Upload(key, wCPU)
		if err != nil {
			return nil, fmt.Errorf("HybridKDA/GPU: upload %s: %w", label, err)
		}
		return ptr, nil
	}

	qWPtr, err := uploadF32("qconv", qConv)
	if err != nil {
		return err
	}
	kWPtr, err := uploadF32("kconv", kConv)
	if err != nil {
		return err
	}
	vWPtr, err := uploadF32("vconv", vConv)
	if err != nil {
		return err
	}

	// 1. Causal short conv1d + SiLU.
	if err := cuda.KDACausalShortConv1D(qCUDA.Data().(unsafe.Pointer), qStateCUDA.Data().(unsafe.Pointer), qWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
		return err
	}
	if err := cuda.KDACausalShortConv1D(kCUDA.Data().(unsafe.Pointer), kStateCUDA.Data().(unsafe.Pointer), kWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
		return err
	}
	if err := cuda.KDACausalShortConv1D(vCUDA.Data().(unsafe.Pointer), vStateCUDA.Data().(unsafe.Pointer), vWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
		return err
	}

	// 2. L2-normalize Q/K per head.
	if err := cuda.L2NormHeads(qCUDA.Data().(unsafe.Pointer), kCUDA.Data().(unsafe.Pointer), seqLen, numHeads, headDim, 1e-6, gpu); err != nil {
		return err
	}

	// 3. Beta projection + sigmoid.
	betaAct, err := alloc(tensor.Shape{seqLen, numHeads})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, bProj, betaAct); err != nil {
		return err
	}
	betaCUDA, err := betaAct.AsCUDA(gpu)
	if err != nil {
		return err
	}
	if err := cuda.Sigmoid(betaCUDA.Data().(unsafe.Pointer), seqLen*numHeads, gpu); err != nil {
		return err
	}

	// 4. Gate computation (f_a -> f_b -> gate).
	gAAct, err := alloc(tensor.Shape{seqLen, headDim})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, fAProj, gAAct); err != nil {
		return err
	}
	gBAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, gAAct, fBProj, gBAct); err != nil {
		return err
	}

	// Upload aLog and dtBias (cached on GPU).
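	// Data flow for the gate below: gBAct is [seqLen, projSize], aLog is reduced to one
	// value per head ([numHeads]) via nn.FlattenALog, and dtBias is uploaded as-is.
	// cuda.KDAGate combines them into gOut [seqLen, projSize], which feeds the recurrent
	// update in step 5.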
	aLogCPU, ok := aLog.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: aLog not cpu tensor")
	}
	aLogFlat, err := nn.FlattenALog(aLogCPU, numHeads)
	if err != nil {
		return err
	}
	aLogView := cpu.NewTensor(tensor.Shape{numHeads}, aLogFlat[:numHeads])
	aLogPtr, err := uploadF32("alog", aLogView)
	if err != nil {
		return err
	}
	dtBiasCPU, ok := dtBias.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: dtBias not cpu tensor")
	}
	dtBiasPtr, err := uploadF32("dtbias", dtBiasCPU)
	if err != nil {
		return err
	}

	gBCUDA, _ := gBAct.AsCUDA(gpu)
	gOutAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	gOutCUDA, _ := gOutAct.AsCUDA(gpu)
	if err := cuda.KDAGate(gBCUDA.Data().(unsafe.Pointer), aLogPtr, dtBiasPtr, gOutCUDA.Data().(unsafe.Pointer), seqLen, numHeads, headDim, gpu); err != nil {
		return err
	}

	// 5. Recurrent state update.
	stateCUDA, ok := recurrentState.(*cuda.Tensor)
	if !ok {
		return fmt.Errorf("HybridKDA/GPU: recurrentState not cuda tensor")
	}
	if err := cuda.KDARecurrent(
		qCUDA.Data().(unsafe.Pointer),
		kCUDA.Data().(unsafe.Pointer),
		vCUDA.Data().(unsafe.Pointer),
		gOutCUDA.Data().(unsafe.Pointer),
		betaCUDA.Data().(unsafe.Pointer),
		stateCUDA.Data().(unsafe.Pointer),
		seqLen, numHeads, headDim, gpu,
	); err != nil {
		return err
	}

	// 6. Output gate (g_a -> g_b).
	gGateAAct, err := alloc(tensor.Shape{seqLen, headDim})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, hidden, gAProj, gGateAAct); err != nil {
		return err
	}
	gGateBAct, err := alloc(tensor.Shape{seqLen, projSize})
	if err != nil {
		return err
	}
	if err := HybridLinear(ctx, gGateAAct, gBProj, gGateBAct); err != nil {
		return err
	}
	gGateBCUDA, _ := gGateBAct.AsCUDA(gpu)

	// 7. Gated RMSNorm (v now holds the core output from the recurrent step).
	if oNorm != nil {
		oNormCPU, ok := oNorm.(*cpu.Tensor)
		if !ok {
			return fmt.Errorf("HybridKDA/GPU: oNorm not cpu tensor")
		}
		oNormPtr, err := uploadF32("onorm", oNormCPU)
		if err != nil {
			return err
		}
		if err := cuda.RMSNormGated(vCUDA.Data().(unsafe.Pointer), gGateBCUDA.Data().(unsafe.Pointer), oNormPtr, seqLen*projSize, headDim, eps, gpu); err != nil {
			return err
		}
	}

	// 8. Output projection.
	coreAct := NewActivationFrom(vCUDA)
	if err := HybridLinear(ctx, coreAct, oProj, attnOut); err != nil {
		return err
	}
	return nil
}

// uploadWeightGPU copies a float32 CPU weight tensor to the given GPU (uncached; the
// caller owns the returned tensor).
func uploadWeightGPU(w tensor.Tensor, gpu int) (*cuda.Tensor, error) {
	wCPU, ok := w.(*cpu.Tensor)
	if !ok {
		return nil, fmt.Errorf("uploadWeightGPU: not a cpu tensor (%T)", w)
	}
	wDev, err := cuda.NewTensor(wCPU.Shape(), tensor.Float32, gpu)
	if err != nil {
		return nil, err
	}
	if err := wDev.CopyFrom(wCPU.DataFloat32()); err != nil {
		wDev.Free()
		return nil, err
	}
	return wDev, nil
}

func hybridKDACPU(
	hidden *Activation,
	qProj, kProj, vProj tensor.Tensor,
	qConv, kConv, vConv tensor.Tensor,
	fAProj, fBProj, bProj tensor.Tensor,
	aLog tensor.Tensor,
	dtBias tensor.Tensor,
	gAProj, gBProj tensor.Tensor,
	oNorm tensor.Tensor,
	oProj tensor.Tensor,
	convQState, convKState, convVState tensor.Tensor,
	recurrentState tensor.Tensor,
	seqLen, numHeads, headDim, shortConvKernel int,
	eps float32,
	attnOut *Activation,
) error {
	projSize := numHeads * headDim
	hiddenCPU, err := hidden.AsCPU()
	if err != nil {
		return err
	}
	qAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	kAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	vAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))

	// Note: we intentionally call HybridLinear(nil, ...) to force the CPU path.
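	// CPU pipeline (mirrors the numbered GPU steps above):
	//
	//	q/k/v projections -> causal short conv + SiLU -> per-head L2 norm
	//	-> beta (sigmoid) -> gate (f_a -> f_b -> KDAGate)
	//	-> recurrent state update (KDARecurrent) -> gated RMSNorm -> output projection (oProj)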
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), qProj, qAct); err != nil {
		return err
	}
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), kProj, kAct); err != nil {
		return err
	}
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), vProj, vAct); err != nil {
		return err
	}
	qCPU, _ := qAct.AsCPU()
	kCPU, _ := kAct.AsCPU()
	vCPU, _ := vAct.AsCPU()

	qConvStateCPU, ok := convQState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: convQState not cpu tensor")
	}
	kConvStateCPU, ok := convKState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: convKState not cpu tensor")
	}
	vConvStateCPU, ok := convVState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: convVState not cpu tensor")
	}
	qConvW, ok := qConv.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: qConv not cpu tensor")
	}
	kConvW, ok := kConv.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: kConv not cpu tensor")
	}
	vConvW, ok := vConv.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: vConv not cpu tensor")
	}

	// Causal short conv1d, updating the conv state in place.
	if err := nn.CausalShortConv1DInplace(qCPU.DataFloat32(), qConvStateCPU, qConvW, seqLen, projSize, shortConvKernel); err != nil {
		return err
	}
	if err := nn.CausalShortConv1DInplace(kCPU.DataFloat32(), kConvStateCPU, kConvW, seqLen, projSize, shortConvKernel); err != nil {
		return err
	}
	if err := nn.CausalShortConv1DInplace(vCPU.DataFloat32(), vConvStateCPU, vConvW, seqLen, projSize, shortConvKernel); err != nil {
		return err
	}

	// Per-head L2 normalization of Q/K.
	nn.L2NormHeads(qCPU.DataFloat32(), kCPU.DataFloat32(), seqLen, numHeads, headDim, 1e-6)

	// Beta projection + sigmoid.
	betaAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, numHeads}, nil))
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), bProj, betaAct); err != nil {
		return err
	}
	betaCPU, _ := betaAct.AsCPU()
	nn.SigmoidInplace(betaCPU.DataFloat32())

	// Gate computation (f_a -> f_b -> KDAGate).
	gAAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, headDim}, nil))
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), fAProj, gAAct); err != nil {
		return err
	}
	gBAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	if err := HybridLinear(nil, gAAct, fBProj, gBAct); err != nil {
		return err
	}
	gBCPU, _ := gBAct.AsCPU()

	aLogCPU, ok := aLog.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: aLog not cpu tensor")
	}
	dtBiasCPU, ok := dtBias.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: dtBias not cpu tensor")
	}
	aLogFlat, err := nn.FlattenALog(aLogCPU, numHeads)
	if err != nil {
		return err
	}
	gOut := make([]float32, seqLen*projSize)
	for t := 0; t < seqLen; t++ {
		gTok := gBCPU.DataFloat32()[t*projSize : (t+1)*projSize]
		gTok2 := nn.KDAGate(gTok, aLogFlat, headDim, dtBiasCPU.DataFloat32())
		copy(gOut[t*projSize:(t+1)*projSize], gTok2)
	}

	// Recurrent state update; coreFlat starts as a copy of V and is overwritten with
	// the core output.
	stCPU, ok := recurrentState.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("KDA: recurrentState not cpu tensor")
	}
	coreFlat := make([]float32, seqLen*projSize)
	copy(coreFlat, vCPU.DataFloat32())
	if err := nn.KDARecurrent(qCPU.DataFloat32(), kCPU.DataFloat32(), coreFlat, gOut, betaCPU.DataFloat32(), stCPU.DataFloat32(), seqLen, numHeads, headDim); err != nil {
		return err
	}

	// Output gate (g_a -> g_b) and gated RMSNorm.
	gGateAAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, headDim}, nil))
	if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), gAProj, gGateAAct); err != nil {
		return err
	}
	gGateBAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
	if err := HybridLinear(nil, gGateAAct, gBProj, gGateBAct); err != nil {
		return err
	}
	gGateBCPU, _ := gGateBAct.AsCPU()
	if oNorm != nil {
		if w, ok := oNorm.(*cpu.Tensor); ok {
			nn.RMSNormGated(coreFlat, gGateBCPU.DataFloat32(), w.DataFloat32(), headDim, eps)
		}
	}

	coreAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, coreFlat))
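	// coreAct wraps coreFlat, the gated core attention output ([seqLen, projSize]);
	// oProj projects it into attnOut (presumably back to the model's hidden size).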
	if err := HybridLinear(nil, coreAct, oProj, attnOut); err != nil {
		return err
	}
	return nil
}
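// Usage sketch (illustrative only; the weight/state containers `w` and `cache` and their
// field names are assumptions, not part of this package; only the argument order of
// HybridKDA is taken from the signature above):
//
//	err := HybridKDA(ctx, hidden,
//		w.QProj, w.KProj, w.VProj,
//		w.QConv, w.KConv, w.VConv,
//		w.FAProj, w.FBProj, w.BProj,
//		w.ALog, w.DTBias,
//		w.GAProj, w.GBProj,
//		w.ONorm, w.OProj,
//		cache.ConvQ, cache.ConvK, cache.ConvV,
//		cache.Recurrent,
//		seqLen, numHeads, headDim, shortConvKernel,
//		eps, attnOut)
//	if err != nil {
//		return err
//	}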