// Package qwen3 implements a device-aware forward pass using Hybrid* operations.
// This enables efficient GPU/CPU offloading without duplicating forward-pass logic.
package qwen3

import (
    "context"
    "fmt"

    "makarna/pkg/backend/cpu"
    "makarna/pkg/backend/cpu/nn"
    "makarna/pkg/compute"
    "makarna/pkg/kvcache"
    "makarna/pkg/model"
    "makarna/pkg/model/arch"
    "makarna/pkg/tensor"
)

// Forward performs a forward pass with automatic device placement.
// If a dispatcher is provided in the context, GPU operations are used for GPU-resident layers.
func (m *Model) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache model.KVCache) (tensor.Tensor, error) {
    cfg := m.config
    seqLen := input.Shape()[0]

    headDim := cfg.HeadDim
    if headDim == 0 {
        headDim = cfg.HiddenSize / cfg.NumHeads
    }
    rmsNormEps := float32(cfg.RMSNormEps)
    if rmsNormEps == 0 {
        rmsNormEps = 1e-6
    }

    // Parse inputs.
    ids := nn.ParseTokenIDs(input)
    posArr := nn.ParsePositions(positions, seqLen)

    dispatcher := compute.DispatcherFromContext(ctx)
    hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
    if err != nil {
        return nil, fmt.Errorf("embedding: %w", err)
    }

    scratchSet := compute.ScratchSetFromContext(ctx)
    baseScratch := compute.ScratchFromContext(ctx)

    // Get the cache - supports both Cache and PagedKVCache via the shared interface.
    var cache kvcache.KVCacheInterface
    if kvCache != nil {
        cache, _ = kvCache.(kvcache.KVCacheInterface)
    }

    // Process transformer layers with device-aware operations.
    lastPlacement := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
    for i, layer := range m.layers {
        // Get the target device for this layer.
        var targetPlacement tensor.DevicePlacement
        if dispatcher != nil {
            targetPlacement = dispatcher.LayerPlacement(i).Normalize()
        } else {
            targetPlacement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
        }

        // Transfer the activation if crossing a device boundary.
        if targetPlacement != lastPlacement {
            if err := compute.EnsureOnDevice(hidden, targetPlacement); err != nil {
                return nil, fmt.Errorf("layer %d: device transfer failed: %w", i, err)
            }
        }
        lastPlacement = targetPlacement

        // Pick (and reset) the scratch space for this layer's GPU, if any.
        var layerScratch *compute.ScratchSpace
        if targetPlacement.Type == tensor.CUDA {
            if scratchSet != nil {
                layerScratch = scratchSet.Scratch(targetPlacement.GPU)
            } else if baseScratch != nil && baseScratch.GPU() == targetPlacement.GPU {
                layerScratch = baseScratch
            }
            if layerScratch != nil {
                layerScratch.Reset()
            }
        }

        // Create the compute context for this layer.
        compCtx := compute.NewContext(dispatcher, i)
        compCtx.Scratch = layerScratch

        // Run the transformer block with device-aware operations.
        if err := m.transformerBlockDeviceAware(compCtx, hidden, layer, posArr, cache, cfg, headDim, rmsNormEps); err != nil {
            return nil, fmt.Errorf("layer %d: %w", i, err)
        }
    }

    // Commit the KV cache.
    if cache != nil {
        cache.Commit(seqLen)
    }

    // Final norm.
    // Create a context tied to the last layer so that, if hidden is already on the GPU,
    // the norm and LM head execute there as well.
    finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
    if lastPlacement.Type == tensor.CUDA {
        if scratchSet != nil {
            finalCtx.Scratch = scratchSet.Scratch(lastPlacement.GPU)
        } else if baseScratch != nil && baseScratch.GPU() == lastPlacement.GPU {
            finalCtx.Scratch = baseScratch
        }
        if finalCtx.Scratch != nil {
            finalCtx.Scratch.Reset()
        }
    }

    // Use HybridRMSNorm (keeps the activation on the GPU if it is already there).
    if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, rmsNormEps); err != nil {
        return nil, fmt.Errorf("final norm: %w", err)
    }

    // LM head.
    // Initialize the logits activation on the CPU; HybridLinear handles device placement.
    // If hidden is on the GPU, HybridLinear executes there and promotes the logits to the GPU.
    logitsAct := compute.NewActivationFrom(
        cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil),
    )
    outputWeights := m.output
    if outputWeights == nil {
        outputWeights = m.tokenEmb
    }
    if err := compute.HybridLinear(finalCtx, hidden, outputWeights, logitsAct); err != nil {
        return nil, fmt.Errorf("lm head: %w", err)
    }

    // Return the logits on the device they were computed on.
    return logitsAct.Tensor(), nil
}

// ForwardBatch performs a device-aware forward pass for a batch of single-token inputs,
// with one KV cache per input token; after the pass each cache is committed by one position.
func (m *Model) ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []model.KVCache) (tensor.Tensor, error) {
    cfg := m.config
    seqLen := input.Shape()[0]
    if len(kvCaches) != seqLen {
        return nil, fmt.Errorf("kvCaches len %d != input len %d", len(kvCaches), seqLen)
    }

    headDim := cfg.HeadDim
    if headDim == 0 {
        headDim = cfg.HiddenSize / cfg.NumHeads
    }
    rmsNormEps := float32(cfg.RMSNormEps)
    if rmsNormEps == 0 {
        rmsNormEps = 1e-6
    }

    ids := nn.ParseTokenIDs(input)
    posArr := nn.ParsePositions(positions, seqLen)

    dispatcher := compute.DispatcherFromContext(ctx)
    hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
    if err != nil {
        return nil, fmt.Errorf("embedding: %w", err)
    }

    scratchSet := compute.ScratchSetFromContext(ctx)
    baseScratch := compute.ScratchFromContext(ctx)

    // Narrow each model.KVCache to the concrete cache interface; nil entries are allowed.
    caches := make([]kvcache.KVCacheInterface, len(kvCaches))
    for i := range kvCaches {
        if kvCaches[i] == nil {
            caches[i] = nil
            continue
        }
        c, ok := kvCaches[i].(kvcache.KVCacheInterface)
        if !ok {
            return nil, fmt.Errorf("kvCache[%d] does not implement KVCacheInterface: %T", i, kvCaches[i])
        }
        caches[i] = c
    }

    // Process transformer layers with device-aware operations (same scheme as Forward).
    lastPlacement := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
    for i, layer := range m.layers {
        var targetPlacement tensor.DevicePlacement
        if dispatcher != nil {
            targetPlacement = dispatcher.LayerPlacement(i).Normalize()
        } else {
            targetPlacement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
        }

        if targetPlacement != lastPlacement {
            if err := compute.EnsureOnDevice(hidden, targetPlacement); err != nil {
                return nil, fmt.Errorf("layer %d: device transfer failed: %w", i, err)
            }
        }
        lastPlacement = targetPlacement

        var layerScratch *compute.ScratchSpace
        if targetPlacement.Type == tensor.CUDA {
            if scratchSet != nil {
                layerScratch = scratchSet.Scratch(targetPlacement.GPU)
            } else if baseScratch != nil && baseScratch.GPU() == targetPlacement.GPU {
                layerScratch = baseScratch
            }
            if layerScratch != nil {
                layerScratch.Reset()
            }
        }

        compCtx := compute.NewContext(dispatcher, i)
        compCtx.Scratch = layerScratch

        blockCfg := arch.HybridDecoderConfig{
            HiddenSize:   cfg.HiddenSize,
            NumHeads:     cfg.NumHeads,
            NumKVHeads:   cfg.NumKVHeads,
            Intermediate: cfg.Intermediate,
            HeadDim:      headDim,
            RopeTheta:    float32(cfg.RopeTheta),
        }
        blockLayer := arch.HybridDecoderLayerWeights{
            Idx:      layer.idx,
            AttnNorm: layer.attnNorm,
            Wq:       layer.wq,
            Wk:       layer.wk,
            Wv:       layer.wv,
            Wo:       layer.wo,
            QNorm:    layer.qNorm,
            KNorm:    layer.kNorm,
            MlpNorm:  layer.mlpNorm,
            WGate:    layer.wGate,
            WUp:      layer.wUp,
            WDown:    layer.wDown,
        }
        if err := arch.HybridDecoderBlockBatch(compCtx, hidden, &blockLayer, posArr, caches, blockCfg, rmsNormEps); err != nil {
            return nil, fmt.Errorf("layer %d: %w", i, err)
        }
    }

    // Commit one position to each per-sequence cache.
    for i := range caches {
        if caches[i] != nil {
            caches[i].Commit(1)
        }
    }

    // Final norm and LM head, kept on the GPU if the activation already lives there.
    finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
    if lastPlacement.Type == tensor.CUDA {
        if scratchSet != nil {
            finalCtx.Scratch = scratchSet.Scratch(lastPlacement.GPU)
        } else if baseScratch != nil && baseScratch.GPU() == lastPlacement.GPU {
            finalCtx.Scratch = baseScratch
        }
        if finalCtx.Scratch != nil {
            finalCtx.Scratch.Reset()
        }
    }

    if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, rmsNormEps); err != nil {
        return nil, fmt.Errorf("final norm: %w", err)
    }

    logitsAct := compute.NewActivationFrom(
        cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil),
    )
    outputWeights := m.output
    if outputWeights == nil {
        outputWeights = m.tokenEmb
    }
    if err := compute.HybridLinear(finalCtx, hidden, outputWeights, logitsAct); err != nil {
        return nil, fmt.Errorf("lm head: %w", err)
    }

    return logitsAct.Tensor(), nil
}

// transformerBlockDeviceAware processes a single transformer layer using device-aware operations.
func (m *Model) transformerBlockDeviceAware(ctx *compute.Context, hidden *compute.Activation, layer *Layer, positions []int, cache kvcache.KVCacheInterface, cfg *model.Config, headDim int, eps float32) error {
    blockCfg := arch.HybridDecoderConfig{
        HiddenSize:   cfg.HiddenSize,
        NumHeads:     cfg.NumHeads,
        NumKVHeads:   cfg.NumKVHeads,
        Intermediate: cfg.Intermediate,
        HeadDim:      headDim,
        RopeTheta:    float32(cfg.RopeTheta),
    }
    blockLayer := arch.HybridDecoderLayerWeights{
        Idx:      layer.idx,
        AttnNorm: layer.attnNorm,
        Wq:       layer.wq,
        Wk:       layer.wk,
        Wv:       layer.wv,
        Wo:       layer.wo,
        QNorm:    layer.qNorm,
        KNorm:    layer.kNorm,
        MlpNorm:  layer.mlpNorm,
        WGate:    layer.wGate,
        WUp:      layer.wUp,
        WDown:    layer.wDown,
    }
    return arch.HybridDecoderBlock(ctx, hidden, &blockLayer, positions, cache, blockCfg, eps)
}
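
// Usage sketch (illustrative only): this file only reads the dispatcher and scratch
// spaces back out of the context via compute.DispatcherFromContext,
// compute.ScratchSetFromContext and compute.ScratchFromContext, so the caller is
// expected to attach them before calling Forward. The setter names below
// (compute.WithDispatcher, compute.WithScratchSet) and the cache value are
// assumptions for illustration; substitute the real helpers exposed by
// pkg/compute and pkg/kvcache.
//
//	ctx := compute.WithDispatcher(context.Background(), dispatcher) // assumed setter
//	ctx = compute.WithScratchSet(ctx, scratchSet)                   // assumed setter
//	logits, err := m.Forward(ctx, inputIDs, positions, cache)
//	if err != nil {
//		// handle error
//	}
//	// logits has shape {seqLen, VocabSize} and stays on whichever device computed it.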