package compute

import (
    "fmt"
    "unsafe"

    "makarna/pkg/backend/cpu"
    "makarna/pkg/backend/cpu/nn"
    "makarna/pkg/backend/cuda"
    "makarna/pkg/backend/device"
    "makarna/pkg/tensor"
)
// HybridKDA runs a Kimi Delta Attention (KDA) block and writes the result to attnOut.
//
// Dispatch:
//   - If the current layer is placed on GPU (ctx.IsGPU() && hidden.IsGPU()), the CUDA path is used.
//   - Otherwise the CPU path is used.
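//
// Illustrative call site (a sketch only; it assumes the caller has already loaded the
// projection/conv weights and allocated the conv and recurrent state tensors):
//
//    if err := compute.HybridKDA(ctx, hidden,
//        qProj, kProj, vProj,
//        qConv, kConv, vConv,
//        fAProj, fBProj, bProj,
//        aLog, dtBias,
//        gAProj, gBProj,
//        oNorm, oProj,
//        convQState, convKState, convVState,
//        recurrentState,
//        seqLen, numHeads, headDim, shortConvKernel,
//        eps, attnOut,
//    ); err != nil {
//        return err
//    }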
func HybridKDA(
    ctx *Context,
    hidden *Activation,
    qProj, kProj, vProj tensor.Tensor,
    qConv, kConv, vConv tensor.Tensor,
    fAProj, fBProj, bProj tensor.Tensor,
    aLog tensor.Tensor,
    dtBias tensor.Tensor,
    gAProj, gBProj tensor.Tensor,
    oNorm tensor.Tensor,
    oProj tensor.Tensor,
    convQState, convKState, convVState tensor.Tensor,
    recurrentState tensor.Tensor,
    seqLen, numHeads, headDim, shortConvKernel int,
    eps float32,
    attnOut *Activation,
) error {
    if ctx != nil && ctx.IsGPU() && hidden.IsGPU() {
        return hybridKDAGPU(
            ctx,
            hidden,
            qProj, kProj, vProj,
            qConv, kConv, vConv,
            fAProj, fBProj, bProj,
            aLog,
            dtBias,
            gAProj, gBProj,
            oNorm,
            oProj,
            convQState, convKState, convVState,
            recurrentState,
            seqLen, numHeads, headDim, shortConvKernel,
            eps,
            attnOut,
        )
    }
    return hybridKDACPU(
        hidden,
        qProj, kProj, vProj,
        qConv, kConv, vConv,
        fAProj, fBProj, bProj,
        aLog,
        dtBias,
        gAProj, gBProj,
        oNorm,
        oProj,
        convQState, convKState, convVState,
        recurrentState,
        seqLen, numHeads, headDim, shortConvKernel,
        eps,
        attnOut,
    )
}
func hybridKDAGPU(
    ctx *Context,
    hidden *Activation,
    qProj, kProj, vProj tensor.Tensor,
    qConv, kConv, vConv tensor.Tensor,
    fAProj, fBProj, bProj tensor.Tensor,
    aLog tensor.Tensor,
    dtBias tensor.Tensor,
    gAProj, gBProj tensor.Tensor,
    oNorm tensor.Tensor,
    oProj tensor.Tensor,
    convQState, convKState, convVState tensor.Tensor,
    recurrentState tensor.Tensor,
    seqLen, numHeads, headDim, shortConvKernel int,
    eps float32,
    attnOut *Activation,
) error {
    if ctx == nil || !ctx.IsGPU() {
        return fmt.Errorf("HybridKDA/GPU: missing GPU context")
    }
    if !device.CUDAAvailable() || !cuda.Available() {
        return fmt.Errorf("HybridKDA/GPU: CUDA not available")
    }
    gpu := ctx.Placement().GPU
    projSize := numHeads * headDim
    // alloc prefers the scratch buffer and falls back to a fresh CUDA allocation.
    alloc := func(shape tensor.Shape) (*Activation, error) {
        if ctx.Scratch != nil {
            if act, err := ctx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
                return act, nil
            }
        }
        return NewActivation(shape, tensor.DevicePlacement{Type: tensor.CUDA, GPU: gpu})
    }
    // Project to Q/K/V on GPU.
    qAct, err := alloc(tensor.Shape{seqLen, projSize})
    if err != nil {
        return err
    }
    kAct, err := alloc(tensor.Shape{seqLen, projSize})
    if err != nil {
        return err
    }
    vAct, err := alloc(tensor.Shape{seqLen, projSize})
    if err != nil {
        return err
    }
    if err := HybridLinear(ctx, hidden, qProj, qAct); err != nil {
        return err
    }
    if err := HybridLinear(ctx, hidden, kProj, kAct); err != nil {
        return err
    }
    if err := HybridLinear(ctx, hidden, vProj, vAct); err != nil {
        return err
    }
    qCUDA, err := qAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    kCUDA, err := kAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    vCUDA, err := vAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    qStateCUDA, ok := convQState.(*cuda.Tensor)
    if !ok {
        return fmt.Errorf("HybridKDA/GPU: convQState not cuda tensor")
    }
    kStateCUDA, ok := convKState.(*cuda.Tensor)
    if !ok {
        return fmt.Errorf("HybridKDA/GPU: convKState not cuda tensor")
    }
    vStateCUDA, ok := convVState.(*cuda.Tensor)
    if !ok {
        return fmt.Errorf("HybridKDA/GPU: convVState not cuda tensor")
    }
    cache := GetWeightCache(gpu)
    // uploadF32 resolves a device pointer for a float32 weight: CUDA tensors are used
    // directly (after device/dtype checks); CPU tensors are uploaded once and cached per layer.
    uploadF32 := func(label string, w tensor.Tensor) (unsafe.Pointer, error) {
        if w == nil {
            return nil, fmt.Errorf("HybridKDA/GPU: missing %s", label)
        }
        if wt, ok := w.(*cuda.Tensor); ok {
            if wt.GPU() != gpu {
                return nil, fmt.Errorf("HybridKDA/GPU: %s on gpu=%d (want %d)", label, wt.GPU(), gpu)
            }
            if wt.DType() != tensor.Float32 {
                return nil, fmt.Errorf("HybridKDA/GPU: %s dtype=%v (want Float32)", label, wt.DType())
            }
            return wt.Data().(unsafe.Pointer), nil
        }
        wCPU, ok := w.(*cpu.Tensor)
        if !ok {
            return nil, fmt.Errorf("HybridKDA/GPU: %s not cpu/cuda tensor (%T)", label, w)
        }
        key := fmt.Sprintf("kda_l%d_%s", ctx.LayerIdx, label)
        if ptr, ok := cache.Get(key); ok {
            return ptr, nil
        }
        ptr, err := cache.Upload(key, wCPU)
        if err != nil {
            return nil, fmt.Errorf("HybridKDA/GPU: upload %s: %w", label, err)
        }
        return ptr, nil
    }
    qWPtr, err := uploadF32("qconv", qConv)
    if err != nil {
        return err
    }
    kWPtr, err := uploadF32("kconv", kConv)
    if err != nil {
        return err
    }
    vWPtr, err := uploadF32("vconv", vConv)
    if err != nil {
        return err
    }
    // 1. Conv1d + SiLU
    if err := cuda.KDACausalShortConv1D(qCUDA.Data().(unsafe.Pointer), qStateCUDA.Data().(unsafe.Pointer), qWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
        return err
    }
    if err := cuda.KDACausalShortConv1D(kCUDA.Data().(unsafe.Pointer), kStateCUDA.Data().(unsafe.Pointer), kWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
        return err
    }
    if err := cuda.KDACausalShortConv1D(vCUDA.Data().(unsafe.Pointer), vStateCUDA.Data().(unsafe.Pointer), vWPtr, seqLen, projSize, shortConvKernel, gpu); err != nil {
        return err
    }
    // 2. L2 Norm Q/K
    if err := cuda.L2NormHeads(qCUDA.Data().(unsafe.Pointer), kCUDA.Data().(unsafe.Pointer), seqLen, numHeads, headDim, 1e-6, gpu); err != nil {
        return err
    }
    // 3. Beta projection + sigmoid
    betaAct, err := alloc(tensor.Shape{seqLen, numHeads})
    if err != nil {
        return err
    }
    if err := HybridLinear(ctx, hidden, bProj, betaAct); err != nil {
        return err
    }
    betaCUDA, err := betaAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    if err := cuda.Sigmoid(betaCUDA.Data().(unsafe.Pointer), seqLen*numHeads, gpu); err != nil {
        return err
    }
    // 4. Gate computation (f_a -> f_b -> gate)
    gAAct, err := alloc(tensor.Shape{seqLen, headDim})
    if err != nil {
        return err
    }
    if err := HybridLinear(ctx, hidden, fAProj, gAAct); err != nil {
        return err
    }
    gBAct, err := alloc(tensor.Shape{seqLen, projSize})
    if err != nil {
        return err
    }
    if err := HybridLinear(ctx, gAAct, fBProj, gBAct); err != nil {
        return err
    }
    // Upload aLog and dtBias (cached on GPU).
    aLogCPU, ok := aLog.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("HybridKDA/GPU: aLog not cpu tensor")
    }
    aLogFlat, err := nn.FlattenALog(aLogCPU, numHeads)
    if err != nil {
        return err
    }
    aLogView := cpu.NewTensor(tensor.Shape{numHeads}, aLogFlat[:numHeads])
    aLogPtr, err := uploadF32("alog", aLogView)
    if err != nil {
        return err
    }
    dtBiasCPU, ok := dtBias.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("HybridKDA/GPU: dtBias not cpu tensor")
    }
    dtBiasPtr, err := uploadF32("dtbias", dtBiasCPU)
    if err != nil {
        return err
    }
    gBCUDA, err := gBAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    gOutAct, err := alloc(tensor.Shape{seqLen, projSize})
    if err != nil {
        return err
    }
    gOutCUDA, err := gOutAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    if err := cuda.KDAGate(gBCUDA.Data().(unsafe.Pointer), aLogPtr, dtBiasPtr, gOutCUDA.Data().(unsafe.Pointer), seqLen, numHeads, headDim, gpu); err != nil {
        return err
    }
    // 5. Recurrent state update
    stateCUDA, ok := recurrentState.(*cuda.Tensor)
    if !ok {
        return fmt.Errorf("HybridKDA/GPU: recurrentState not cuda tensor")
    }
    if err := cuda.KDARecurrent(
        qCUDA.Data().(unsafe.Pointer),
        kCUDA.Data().(unsafe.Pointer),
        vCUDA.Data().(unsafe.Pointer),
        gOutCUDA.Data().(unsafe.Pointer),
        betaCUDA.Data().(unsafe.Pointer),
        stateCUDA.Data().(unsafe.Pointer),
        seqLen, numHeads, headDim, gpu,
    ); err != nil {
        return err
    }
    // 6. Output gate (g_a -> g_b)
    gGateAAct, err := alloc(tensor.Shape{seqLen, headDim})
    if err != nil {
        return err
    }
    if err := HybridLinear(ctx, hidden, gAProj, gGateAAct); err != nil {
        return err
    }
    gGateBAct, err := alloc(tensor.Shape{seqLen, projSize})
    if err != nil {
        return err
    }
    if err := HybridLinear(ctx, gGateAAct, gBProj, gGateBAct); err != nil {
        return err
    }
    gGateBCUDA, err := gGateBAct.AsCUDA(gpu)
    if err != nil {
        return err
    }
    // 7. RMSNorm gated (v now contains output from recurrent)
    if oNorm != nil {
        oNormCPU, ok := oNorm.(*cpu.Tensor)
        if !ok {
            return fmt.Errorf("HybridKDA/GPU: oNorm not cpu tensor")
        }
        oNormPtr, err := uploadF32("onorm", oNormCPU)
        if err != nil {
            return err
        }
        if err := cuda.RMSNormGated(vCUDA.Data().(unsafe.Pointer), gGateBCUDA.Data().(unsafe.Pointer), oNormPtr, seqLen*projSize, headDim, eps, gpu); err != nil {
            return err
        }
    }
    // 8. Output projection
    coreAct := NewActivationFrom(vCUDA)
    if err := HybridLinear(ctx, coreAct, oProj, attnOut); err != nil {
        return err
    }
    return nil
}
// uploadWeightGPU copies a CPU float32 weight tensor to the given GPU.
func uploadWeightGPU(w tensor.Tensor, gpu int) (*cuda.Tensor, error) {
    wCPU, ok := w.(*cpu.Tensor)
    if !ok {
        return nil, fmt.Errorf("uploadWeightGPU: not a cpu tensor (%T)", w)
    }
    wDev, err := cuda.NewTensor(wCPU.Shape(), tensor.Float32, gpu)
    if err != nil {
        return nil, err
    }
    if err := wDev.CopyFrom(wCPU.DataFloat32()); err != nil {
        wDev.Free()
        return nil, err
    }
    return wDev, nil
}
func hybridKDACPU(
    hidden *Activation,
    qProj, kProj, vProj tensor.Tensor,
    qConv, kConv, vConv tensor.Tensor,
    fAProj, fBProj, bProj tensor.Tensor,
    aLog tensor.Tensor,
    dtBias tensor.Tensor,
    gAProj, gBProj tensor.Tensor,
    oNorm tensor.Tensor,
    oProj tensor.Tensor,
    convQState, convKState, convVState tensor.Tensor,
    recurrentState tensor.Tensor,
    seqLen, numHeads, headDim, shortConvKernel int,
    eps float32,
    attnOut *Activation,
) error {
    projSize := numHeads * headDim
    hiddenCPU, err := hidden.AsCPU()
    if err != nil {
        return err
    }
    qAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
    kAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
    vAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
    // Note: we intentionally call HybridLinear(nil, ...) to force CPU path.
    if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), qProj, qAct); err != nil {
        return err
    }
    if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), kProj, kAct); err != nil {
        return err
    }
    if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), vProj, vAct); err != nil {
        return err
    }
    qCPU, _ := qAct.AsCPU()
    kCPU, _ := kAct.AsCPU()
    vCPU, _ := vAct.AsCPU()
    qConvStateCPU, ok := convQState.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: convQState not cpu tensor")
    }
    kConvStateCPU, ok := convKState.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: convKState not cpu tensor")
    }
    vConvStateCPU, ok := convVState.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: convVState not cpu tensor")
    }
    qConvW, ok := qConv.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: qConv not cpu tensor")
    }
    kConvW, ok := kConv.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: kConv not cpu tensor")
    }
    vConvW, ok := vConv.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: vConv not cpu tensor")
    }
    // 1. Causal short conv1d (in place).
    if err := nn.CausalShortConv1DInplace(qCPU.DataFloat32(), qConvStateCPU, qConvW, seqLen, projSize, shortConvKernel); err != nil {
        return err
    }
    if err := nn.CausalShortConv1DInplace(kCPU.DataFloat32(), kConvStateCPU, kConvW, seqLen, projSize, shortConvKernel); err != nil {
        return err
    }
    if err := nn.CausalShortConv1DInplace(vCPU.DataFloat32(), vConvStateCPU, vConvW, seqLen, projSize, shortConvKernel); err != nil {
        return err
    }
    // 2. L2-normalize Q/K per head.
    nn.L2NormHeads(qCPU.DataFloat32(), kCPU.DataFloat32(), seqLen, numHeads, headDim, 1e-6)
    // 3. Beta projection + sigmoid.
    betaAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, numHeads}, nil))
    if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), bProj, betaAct); err != nil {
        return err
    }
    betaCPU, _ := betaAct.AsCPU()
    nn.SigmoidInplace(betaCPU.DataFloat32())
    // 4. Gate computation (f_a -> f_b -> KDA gate).
    gAAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, headDim}, nil))
    if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), fAProj, gAAct); err != nil {
        return err
    }
    gBAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
    if err := HybridLinear(nil, gAAct, fBProj, gBAct); err != nil {
        return err
    }
    gBCPU, _ := gBAct.AsCPU()
    aLogCPU, ok := aLog.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: aLog not cpu tensor")
    }
    dtBiasCPU, ok := dtBias.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: dtBias not cpu tensor")
    }
    aLogFlat, err := nn.FlattenALog(aLogCPU, numHeads)
    if err != nil {
        return err
    }
    gOut := make([]float32, seqLen*projSize)
    for t := 0; t < seqLen; t++ {
        gTok := gBCPU.DataFloat32()[t*projSize : (t+1)*projSize]
        gTok2 := nn.KDAGate(gTok, aLogFlat, headDim, dtBiasCPU.DataFloat32())
        copy(gOut[t*projSize:(t+1)*projSize], gTok2)
    }
    // 5. Recurrent state update; coreFlat (a copy of V) receives the recurrent output.
    stCPU, ok := recurrentState.(*cpu.Tensor)
    if !ok {
        return fmt.Errorf("KDA: recurrentState not cpu tensor")
    }
    coreFlat := make([]float32, seqLen*projSize)
    copy(coreFlat, vCPU.DataFloat32())
    if err := nn.KDARecurrent(qCPU.DataFloat32(), kCPU.DataFloat32(), coreFlat, gOut, betaCPU.DataFloat32(), stCPU.DataFloat32(), seqLen, numHeads, headDim); err != nil {
        return err
    }
    // 6. Output gate (g_a -> g_b).
    gGateAAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, headDim}, nil))
    if err := HybridLinear(nil, NewActivationFrom(hiddenCPU), gAProj, gGateAAct); err != nil {
        return err
    }
    gGateBAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, nil))
    if err := HybridLinear(nil, gGateAAct, gBProj, gGateBAct); err != nil {
        return err
    }
    gGateBCPU, _ := gGateBAct.AsCPU()
    // 7. Gated RMSNorm.
    if oNorm != nil {
        if w, ok := oNorm.(*cpu.Tensor); ok {
            nn.RMSNormGated(coreFlat, gGateBCPU.DataFloat32(), w.DataFloat32(), headDim, eps)
        }
    }
    // 8. Output projection.
    coreAct := NewActivationFrom(cpu.NewTensor(tensor.Shape{seqLen, projSize}, coreFlat))
    if err := HybridLinear(nil, coreAct, oProj, attnOut); err != nil {
        return err
    }
    return nil
}