// Package qwen3 implements a device-aware forward pass using Hybrid* operations.
// This enables efficient GPU/CPU offloading without duplicating forward pass logic.
package qwen3

import (
    "context"
    "fmt"

    "makarna/pkg/backend/cpu"
    "makarna/pkg/backend/cpu/nn"
    "makarna/pkg/compute"
    "makarna/pkg/kvcache"
    "makarna/pkg/model"
    "makarna/pkg/model/arch"
    "makarna/pkg/tensor"
)
// Forward performs a forward pass with automatic device placement.
// If a dispatcher is provided in the context, GPU operations are used for GPU layers.
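//
// A minimal usage sketch, assuming the caller has already attached a dispatcher
// (and optional scratch spaces) to ctx via the compute package; the setup
// helpers for that are omitted here and only illustrative:
//
//     logits, err := m.Forward(ctx, inputIDs, positions, cache)
//     if err != nil {
//         // handle error
//     }
//     // logits has shape {seqLen, vocabSize} and stays on whichever device
//     // computed the LM head.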
func (m *Model) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache model.KVCache) (tensor.Tensor, error) {
    cfg := m.config
    seqLen := input.Shape()[0]

    headDim := cfg.HeadDim
    if headDim == 0 {
        headDim = cfg.HiddenSize / cfg.NumHeads
    }
    rmsNormEps := float32(cfg.RMSNormEps)
    if rmsNormEps == 0 {
        rmsNormEps = 1e-6
    }

    // Parse inputs
    ids := nn.ParseTokenIDs(input)
    posArr := nn.ParsePositions(positions, seqLen)

    dispatcher := compute.DispatcherFromContext(ctx)
    hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
    if err != nil {
        return nil, fmt.Errorf("embedding: %w", err)
    }

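    // Optional scratch spaces: a per-GPU set or a single base scratch may be
    // supplied through the context and is reused for GPU layers below.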
    scratchSet := compute.ScratchSetFromContext(ctx)
    baseScratch := compute.ScratchFromContext(ctx)

    // Get cache - supports both Cache and PagedKVCache via interface
    var cache kvcache.KVCacheInterface
    if kvCache != nil {
        cache, _ = kvCache.(kvcache.KVCacheInterface)
    }

    // Process transformer layers with device-aware operations
    lastPlacement := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
    for i, layer := range m.layers {
        // Get target device for this layer
        var targetPlacement tensor.DevicePlacement
        if dispatcher != nil {
            targetPlacement = dispatcher.LayerPlacement(i).Normalize()
        } else {
            targetPlacement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
        }

        // Transfer activation if crossing device boundary
        if targetPlacement != lastPlacement {
            if err := compute.EnsureOnDevice(hidden, targetPlacement); err != nil {
                return nil, fmt.Errorf("layer %d: device transfer failed: %w", i, err)
            }
        }
        lastPlacement = targetPlacement
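
        // Pick the scratch space for this layer's GPU (if any) and reset it
        // before reuse.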
        var layerScratch *compute.ScratchSpace
        if targetPlacement.Type == tensor.CUDA {
            if scratchSet != nil {
                layerScratch = scratchSet.Scratch(targetPlacement.GPU)
            } else if baseScratch != nil && baseScratch.GPU() == targetPlacement.GPU {
                layerScratch = baseScratch
            }
            if layerScratch != nil {
                layerScratch.Reset()
            }
        }

        // Create compute context for this layer
        compCtx := compute.NewContext(dispatcher, i)
        compCtx.Scratch = layerScratch

        // Run transformer block with device-aware operations
        if err := m.transformerBlockDeviceAware(compCtx, hidden, layer, posArr, cache, cfg, headDim, rmsNormEps); err != nil {
            return nil, fmt.Errorf("layer %d: %w", i, err)
        }
    }

    // Commit KV cache
    if cache != nil {
        cache.Commit(seqLen)
    }

    // Final norm: build a context for the last layer's placement so that, when
    // hidden is already on a GPU, the norm and LM head stay there.
    finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
    if lastPlacement.Type == tensor.CUDA {
        if scratchSet != nil {
            finalCtx.Scratch = scratchSet.Scratch(lastPlacement.GPU)
        } else if baseScratch != nil && baseScratch.GPU() == lastPlacement.GPU {
            finalCtx.Scratch = baseScratch
        }
        if finalCtx.Scratch != nil {
            finalCtx.Scratch.Reset()
        }
    }

    // Use HybridRMSNorm (keeps on GPU if already there)
    if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, rmsNormEps); err != nil {
        return nil, fmt.Errorf("final norm: %w", err)
    }

    // LM head
    // Initialize the logits activation on the CPU; HybridLinear handles device
    // placement and, if hidden is on a GPU, executes there and promotes logits to GPU.
    logitsAct := compute.NewActivationFrom(
        cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil),
    )
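    // No separate output projection: fall back to the token embedding matrix
    // (tied weights).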
    outputWeights := m.output
    if outputWeights == nil {
        outputWeights = m.tokenEmb
    }
    if err := compute.HybridLinear(finalCtx, hidden, outputWeights, logitsAct); err != nil {
        return nil, fmt.Errorf("lm head: %w", err)
    }

    // Return logits on the device they were computed on.
    return logitsAct.Tensor(), nil
}
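
// ForwardBatch runs one decode step for a batch of independent sequences:
// input carries one token per sequence, kvCaches supplies the matching cache
// for each sequence, and every non-nil cache is committed with a single new
// position after the pass.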
func (m *Model) ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []model.KVCache) (tensor.Tensor, error) {
    cfg := m.config
    seqLen := input.Shape()[0]
    if len(kvCaches) != seqLen {
        return nil, fmt.Errorf("kvCaches len %d != input len %d", len(kvCaches), seqLen)
    }

    headDim := cfg.HeadDim
    if headDim == 0 {
        headDim = cfg.HiddenSize / cfg.NumHeads
    }
    rmsNormEps := float32(cfg.RMSNormEps)
    if rmsNormEps == 0 {
        rmsNormEps = 1e-6
    }

    ids := nn.ParseTokenIDs(input)
    posArr := nn.ParsePositions(positions, seqLen)

    dispatcher := compute.DispatcherFromContext(ctx)
    hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
    if err != nil {
        return nil, fmt.Errorf("embedding: %w", err)
    }

    scratchSet := compute.ScratchSetFromContext(ctx)
    baseScratch := compute.ScratchFromContext(ctx)
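
    // Unwrap the per-sequence caches; each non-nil cache must implement
    // kvcache.KVCacheInterface.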
    caches := make([]kvcache.KVCacheInterface, len(kvCaches))
    for i := range kvCaches {
        if kvCaches[i] == nil {
            caches[i] = nil
            continue
        }
        c, ok := kvCaches[i].(kvcache.KVCacheInterface)
        if !ok {
            return nil, fmt.Errorf("kvCache[%d] does not implement KVCacheInterface: %T", i, kvCaches[i])
        }
        caches[i] = c
    }

    lastPlacement := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
    for i, layer := range m.layers {
        var targetPlacement tensor.DevicePlacement
        if dispatcher != nil {
            targetPlacement = dispatcher.LayerPlacement(i).Normalize()
        } else {
            targetPlacement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
        }
        if targetPlacement != lastPlacement {
            if err := compute.EnsureOnDevice(hidden, targetPlacement); err != nil {
                return nil, fmt.Errorf("layer %d: device transfer failed: %w", i, err)
            }
        }
        lastPlacement = targetPlacement

        var layerScratch *compute.ScratchSpace
        if targetPlacement.Type == tensor.CUDA {
            if scratchSet != nil {
                layerScratch = scratchSet.Scratch(targetPlacement.GPU)
            } else if baseScratch != nil && baseScratch.GPU() == targetPlacement.GPU {
                layerScratch = baseScratch
            }
            if layerScratch != nil {
                layerScratch.Reset()
            }
        }

        compCtx := compute.NewContext(dispatcher, i)
        compCtx.Scratch = layerScratch
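
        // Map the model config and this layer's weights onto the generic
        // hybrid decoder block used for batched decoding.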
        blockCfg := arch.HybridDecoderConfig{
            HiddenSize:   cfg.HiddenSize,
            NumHeads:     cfg.NumHeads,
            NumKVHeads:   cfg.NumKVHeads,
            Intermediate: cfg.Intermediate,
            HeadDim:      headDim,
            RopeTheta:    float32(cfg.RopeTheta),
        }
        blockLayer := arch.HybridDecoderLayerWeights{
            Idx:      layer.idx,
            AttnNorm: layer.attnNorm,
            Wq:       layer.wq,
            Wk:       layer.wk,
            Wv:       layer.wv,
            Wo:       layer.wo,
            QNorm:    layer.qNorm,
            KNorm:    layer.kNorm,
            MlpNorm:  layer.mlpNorm,
            WGate:    layer.wGate,
            WUp:      layer.wUp,
            WDown:    layer.wDown,
        }
        if err := arch.HybridDecoderBlockBatch(compCtx, hidden, &blockLayer, posArr, caches, blockCfg, rmsNormEps); err != nil {
            return nil, fmt.Errorf("layer %d: %w", i, err)
        }
    }
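
    // Commit one new position to every per-sequence cache.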
    for i := range caches {
        if caches[i] != nil {
            caches[i].Commit(1)
        }
    }
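
    // Final norm and LM head mirror Forward: keep execution on the last
    // layer's device when it is a GPU.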
    finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
    if lastPlacement.Type == tensor.CUDA {
        if scratchSet != nil {
            finalCtx.Scratch = scratchSet.Scratch(lastPlacement.GPU)
        } else if baseScratch != nil && baseScratch.GPU() == lastPlacement.GPU {
            finalCtx.Scratch = baseScratch
        }
        if finalCtx.Scratch != nil {
            finalCtx.Scratch.Reset()
        }
    }
    if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, rmsNormEps); err != nil {
        return nil, fmt.Errorf("final norm: %w", err)
    }

    logitsAct := compute.NewActivationFrom(
        cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil),
    )
    outputWeights := m.output
    if outputWeights == nil {
        outputWeights = m.tokenEmb
    }
    if err := compute.HybridLinear(finalCtx, hidden, outputWeights, logitsAct); err != nil {
        return nil, fmt.Errorf("lm head: %w", err)
    }
    return logitsAct.Tensor(), nil
}

// transformerBlockDeviceAware processes a single transformer layer using device-aware operations.
func (m *Model) transformerBlockDeviceAware(ctx *compute.Context, hidden *compute.Activation, layer *Layer, positions []int, cache kvcache.KVCacheInterface, cfg *model.Config, headDim int, eps float32) error {
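    // Translate the model config and per-layer weights into the form expected
    // by the generic hybrid decoder block.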
    blockCfg := arch.HybridDecoderConfig{
        HiddenSize:   cfg.HiddenSize,
        NumHeads:     cfg.NumHeads,
        NumKVHeads:   cfg.NumKVHeads,
        Intermediate: cfg.Intermediate,
        HeadDim:      headDim,
        RopeTheta:    float32(cfg.RopeTheta),
    }
    blockLayer := arch.HybridDecoderLayerWeights{
        Idx:      layer.idx,
        AttnNorm: layer.attnNorm,
        Wq:       layer.wq,
        Wk:       layer.wk,
        Wv:       layer.wv,
        Wo:       layer.wo,
        QNorm:    layer.qNorm,
        KNorm:    layer.kNorm,
        MlpNorm:  layer.mlpNorm,
        WGate:    layer.wGate,
        WUp:      layer.wUp,
        WDown:    layer.wDown,
    }
    return arch.HybridDecoderBlock(ctx, hidden, &blockLayer, positions, cache, blockCfg, eps)
}