package arch

import (
	"fmt"
	"math"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/compute"
	"makarna/pkg/kvcache"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)

// HybridDecoderLayerWeights holds the per-layer weight tensors for one hybrid decoder block.
type HybridDecoderLayerWeights struct {
	Idx      int
	AttnNorm tensor.Tensor
	Wq       tensor.Tensor
	Wk       tensor.Tensor
	Wv       tensor.Tensor
	Wo       tensor.Tensor
	QNorm    tensor.Tensor
	KNorm    tensor.Tensor
	MlpNorm  tensor.Tensor
	WGate    tensor.Tensor
	WUp      tensor.Tensor
	WDown    tensor.Tensor
}

// HybridDecoderConfig describes the attention and MLP geometry of a hybrid decoder block.
type HybridDecoderConfig struct {
	HiddenSize   int
	NumHeads     int
	NumKVHeads   int
	Intermediate int
	HeadDim      int
	RopeTheta    float32
}

// HybridDecoderBlock runs one decoder layer (attention and MLP, each with RMSNorm and a
// residual add) over hidden in place, using cache for the attention KV history.
func HybridDecoderBlock(ctx *compute.Context, hidden *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, cache kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) error {
	cfg = normalizeHybridDecoderConfig(cfg)
	if err := ensureActivationOnContextDevice(ctx, hidden); err != nil {
		return err
	}
	alloc := makeHybridAllocator(ctx)
	residual, err := cloneActivation(ctx, alloc, hidden, false)
	if err != nil {
		return err
	}
	postAttn, err := runHybridAttentionBlock(ctx, alloc, hidden, residual, layer, positions, cache, cfg, eps)
	if err != nil {
		return err
	}
	out, err := runHybridMLPBlock(ctx, alloc, postAttn, layer, cfg, eps)
	if err != nil {
		return err
	}
	hidden.ReplaceWith(out.Tensor())
	return nil
}

// HybridDecoderBlockBatch is the batched variant of HybridDecoderBlock: each row of hidden is
// one independent sequence with its own KV cache in caches.
func HybridDecoderBlockBatch(ctx *compute.Context, hidden *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, caches []kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) error {
	cfg = normalizeHybridDecoderConfig(cfg)
	if err := ensureActivationOnContextDevice(ctx, hidden); err != nil {
		return err
	}
	if len(caches) != hidden.Shape()[0] {
		return fmt.Errorf("caches len %d != hidden batch %d", len(caches), hidden.Shape()[0])
	}
	alloc := makeHybridAllocator(ctx)
	residual, err := cloneActivation(ctx, alloc, hidden, false)
	if err != nil {
		return err
	}
	postAttn, err := runHybridAttentionBlockBatch(ctx, alloc, hidden, residual, layer, positions, caches, cfg, eps)
	if err != nil {
		return err
	}
	out, err := runHybridMLPBlock(ctx, alloc, postAttn, layer, cfg, eps)
	if err != nil {
		return err
	}
	hidden.ReplaceWith(out.Tensor())
	return nil
}

func normalizeHybridDecoderConfig(cfg HybridDecoderConfig) HybridDecoderConfig {
	if cfg.NumKVHeads == 0 {
		cfg.NumKVHeads = cfg.NumHeads
	}
	if cfg.Intermediate == 0 {
		cfg.Intermediate = cfg.HiddenSize * 4
	}
	return cfg
}

func ensureActivationOnContextDevice(ctx *compute.Context, a *compute.Activation) error {
	if ctx == nil {
		return nil
	}
	if _, err := a.EnsureOn(ctx.Placement()); err != nil {
		return fmt.Errorf("move hidden to target device: %w", err)
	}
	return nil
}

type hybridAllocFn func(shape tensor.Shape, placement tensor.DevicePlacement, allowScratch bool) (*compute.Activation, error)

func makeHybridAllocator(ctx *compute.Context) hybridAllocFn {
	return func(shape tensor.Shape, placement tensor.DevicePlacement, allowScratch bool) (*compute.Activation, error) {
		placement = placement.Normalize()
		if allowScratch && ctx != nil && ctx.Scratch != nil && placement.Type == tensor.CUDA && ctx.Scratch.GPU() == placement.GPU {
			if a, err := ctx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(shape, placement)
	}
}

func cloneActivation(ctx *compute.Context, alloc hybridAllocFn, src *compute.Activation, allowScratch bool) (*compute.Activation, error) {
	dst, err := alloc(src.Shape(), src.Placement(), allowScratch)
	if err != nil {
		return nil, fmt.Errorf("alloc residual: %w", err)
	}
	if err := compute.HybridCopy(ctx, dst, src); err != nil {
		return nil, fmt.Errorf("copy residual: %w", err)
	}
	return dst, nil
}

func runHybridAttentionBlock(ctx *compute.Context, alloc hybridAllocFn, hidden *compute.Activation, residual *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, cache kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := hidden.Shape()[0]

	profile.Start("Block/AttnNorm")
	if err := compute.HybridRMSNorm(ctx, hidden, layer.AttnNorm, eps); err != nil {
		profile.End("Block/AttnNorm")
		return nil, fmt.Errorf("attn norm: %w", err)
	}
	profile.End("Block/AttnNorm")

	qOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	kOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	vOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}

	profile.Start("Block/QProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wq, qOut); err != nil {
		profile.End("Block/QProj")
		return nil, fmt.Errorf("q_proj: %w", err)
	}
	profile.End("Block/QProj")

	profile.Start("Block/KProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wk, kOut); err != nil {
		profile.End("Block/KProj")
		return nil, fmt.Errorf("k_proj: %w", err)
	}
	profile.End("Block/KProj")

	profile.Start("Block/VProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wv, vOut); err != nil {
		profile.End("Block/VProj")
		return nil, fmt.Errorf("v_proj: %w", err)
	}
	profile.End("Block/VProj")

	if layer.QNorm != nil {
		profile.Start("Block/QNorm")
		if err := compute.HybridRMSNorm(ctx, qOut, layer.QNorm, eps); err != nil {
			profile.End("Block/QNorm")
			return nil, fmt.Errorf("q_norm: %w", err)
		}
		profile.End("Block/QNorm")
	}
	if layer.KNorm != nil {
		profile.Start("Block/KNorm")
		if err := compute.HybridRMSNorm(ctx, kOut, layer.KNorm, eps); err != nil {
			profile.End("Block/KNorm")
			return nil, fmt.Errorf("k_norm: %w", err)
		}
		profile.End("Block/KNorm")
	}

	// Decide whether we can fuse RoPE inside the paged attention kernel.
	// This skips the standalone RoPE kernel launches and avoids an extra read/modify/write of Q/K.
	layerCacheDev := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	var pc *kvcache.PagedKVCache
	useFusedRoPE := false
	if cache != nil {
		layerCacheDev = cache.LayerDevice(layer.Idx).Normalize()
		if p, ok := cache.(*kvcache.PagedKVCache); ok {
			pc = p
		}
	}
	if pc != nil && layerCacheDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		gpu := ctx.Placement().GPU
		// Fused kernel supports headDim<=256 and even headDim.
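		// Note (restating the eligibility check below): the cache layer must already live on
		// this context's GPU and RopeTheta must be non-zero; otherwise we fall back to the
		// standalone RoPE kernels and the unfused attention path.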
		if pc.LayerDevice(layer.Idx).GPU == gpu && cfg.HeadDim <= 256 && (cfg.HeadDim&1) == 0 && cfg.RopeTheta != 0 {
			useFusedRoPE = true
		}
	}

	if !useFusedRoPE {
		profile.Start("Block/RoPE_Q")
		if err := compute.HybridRoPE(ctx, qOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_Q")
			return nil, fmt.Errorf("rope q: %w", err)
		}
		profile.End("Block/RoPE_Q")

		profile.Start("Block/RoPE_K")
		if err := compute.HybridRoPE(ctx, kOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_K")
			return nil, fmt.Errorf("rope k: %w", err)
		}
		profile.End("Block/RoPE_K")
	}

	attnOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	scale := float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	profile.Start("Block/Attention")
	didPaged := false
	if cache != nil && layerCacheDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		if pc != nil {
			startPos := cache.SeqLen()
			gpu := ctx.Placement().GPU
			if pc.LayerDevice(layer.Idx).GPU != gpu {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("paged attention requires cache layer %d on gpu %d (got gpu %d)", layer.Idx, gpu, pc.LayerDevice(layer.Idx).GPU)
			}
			if _, _, err := cache.Append(layer.Idx, kOut.Tensor(), vOut.Tensor()); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			kvLen := startPos + seqLen
			bs := pc.BlockSize()
			numBlocks := (kvLen + bs - 1) / bs
			var (
				kDev      unsafe.Pointer
				vDev      unsafe.Pointer
				freeAfter bool
				kvType    tensor.DType
				blockSize int
			)
			// Prefer persistent per-layer device pointer tables (rebuilt only when numBlocks grows).
			kDev, vDev, blockSize, kvType, err = pc.LayerDevicePtrTables(layer.Idx, numBlocks)
			if err != nil {
				// Fallback to per-call scratch allocation/copy if needed.
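				// The persistent tables were unavailable, so rebuild host-side pointer tables for
				// this call. When a scratch arena is attached to the context we stage them there;
				// otherwise we allocate one-off device buffers and mark them for release via freeAfter.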
				kPtrs, vPtrs, blockSize2, kvType2, err2 := pc.LayerBlockPtrTables(layer.Idx, numBlocks)
				if err2 != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("kv ptr tables: %w", err2)
				}
				blockSize = blockSize2
				kvType = kvType2
				if ctx != nil && ctx.Scratch != nil {
					kDev, err = ctx.Scratch.GetUintptrSlice(len(kPtrs))
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("scratch K ptr table: %w", err)
					}
					vDev, err = ctx.Scratch.GetUintptrSlice(len(vPtrs))
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("scratch V ptr table: %w", err)
					}
					if err := cuda.MemcpyH2D(kDev, unsafe.Pointer(&kPtrs[0]), uintptr(len(kPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("memcpy K ptr table: %w", err)
					}
					if err := cuda.MemcpyH2D(vDev, unsafe.Pointer(&vPtrs[0]), uintptr(len(vPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("memcpy V ptr table: %w", err)
					}
				} else {
					freeAfter = true
					kDev, err = cuda.AllocAndCopyPtrTable(kPtrs, gpu)
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("alloc K ptr table: %w", err)
					}
					vDev, err = cuda.AllocAndCopyPtrTable(vPtrs, gpu)
					if err != nil {
						cuda.FreeDevicePtr(kDev)
						profile.End("Block/Attention")
						return nil, fmt.Errorf("alloc V ptr table: %w", err)
					}
				}
			}
			if freeAfter {
				defer cuda.FreeDevicePtr(kDev)
				defer cuda.FreeDevicePtr(vDev)
			}
			gpuQ, err := qOut.AsCUDA(gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("q to cuda: %w", err)
			}
			gpuOut, err := attnOut.AsCUDA(gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attn out to cuda: %w", err)
			}
			// Single-request fast path: use the non-batch paged attention kernel for any seqLen.
			// This avoids per-call device allocations for blockOffsets/kvLens/queryPos.
			if kvType == tensor.Float16 {
				if useFusedRoPE {
					err = cuda.PagedAttentionRoPEF32F16KV(
						gpuQ.Data().(unsafe.Pointer), kDev, vDev, gpuOut.Data().(unsafe.Pointer),
						seqLen, kvLen, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
						blockSize, scale, startPos, cfg.RopeTheta, gpu,
					)
				} else {
					err = cuda.PagedAttentionF32F16KV(
						gpuQ.Data().(unsafe.Pointer), kDev, vDev, gpuOut.Data().(unsafe.Pointer),
						seqLen, kvLen, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
						blockSize, scale, startPos, gpu,
					)
				}
			} else {
				err = cuda.PagedAttention(
					gpuQ.Data().(unsafe.Pointer), kDev, vDev, gpuOut.Data().(unsafe.Pointer),
					seqLen, kvLen, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize, scale, startPos, gpu,
				)
			}
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("paged attention: %w", err)
			}
			didPaged = true
		}
	}
	if !didPaged {
		// CPU path: do NOT materialize full K/V history. That is O(kvLen*kvDim) per step
		// and causes severe context-dependent slowdown. Instead, run attention directly
		// over the KV cache block views.
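		// In other words, each step only touches the cache blocks it needs, so the per-token cost
		// does not pay an extra O(kvLen*kvDim) gather on top of the attention itself.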
		if cache != nil && layerCacheDev.Type != tensor.CUDA {
			startPos := cache.SeqLen()
			qCPU, err := qOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("q to cpu: %w", err)
			}
			kCPU, err := kOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("k to cpu: %w", err)
			}
			vCPU, err := vOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("v to cpu: %w", err)
			}
			outCPU, err := attnOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attn out to cpu: %w", err)
			}
			views, _, err := cache.Append(layer.Idx, kCPU, vCPU)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
				pviews := pv.ViewsPacked(layer.Idx)
				if len(pviews) != 0 {
					if err := nn.CausalAttentionPackedBlocks(qCPU, pviews, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention packed: %w", err)
					}
				} else {
					if err := nn.CausalAttentionBlocks(qCPU, views, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention blocks: %w", err)
					}
				}
			} else {
				if err := nn.CausalAttentionBlocks(qCPU, views, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("attention blocks: %w", err)
				}
			}
			// If the layer itself is on GPU, restore attention output to GPU so the output
			// projection runs on GPU even when KV cache is on CPU.
			if ctx != nil && ctx.IsGPU() {
				if _, err := attnOut.EnsureOn(ctx.Placement()); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("move attn out to gpu: %w", err)
				}
			}
		} else {
			fullK, fullV, startPos, err := gatherKV(ctx, cache, layer.Idx, kOut, vOut, cfg.NumKVHeads*cfg.HeadDim)
			if err != nil {
				profile.End("Block/Attention")
				return nil, err
			}
			if err := compute.HybridAttention(ctx, qOut, fullK, fullV, attnOut, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, scale, startPos); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attention: %w", err)
			}
		}
	}
	profile.End("Block/Attention")

	attnProj, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/OProj")
	if err := compute.HybridLinear(ctx, attnOut, layer.Wo, attnProj); err != nil {
		profile.End("Block/OProj")
		return nil, fmt.Errorf("o_proj: %w", err)
	}
	profile.End("Block/OProj")

	if err := compute.HybridAdd(ctx, residual, attnProj); err != nil {
		return nil, fmt.Errorf("residual 1: %w", err)
	}
	return residual, nil
}

func runHybridAttentionBlockBatch(ctx *compute.Context, alloc hybridAllocFn, hidden *compute.Activation, residual *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, caches []kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := hidden.Shape()[0]
	if len(caches) != seqLen {
		return nil, fmt.Errorf("caches len %d != seqLen %d", len(caches), seqLen)
	}
	if len(positions) != seqLen {
		return nil, fmt.Errorf("positions len %d != seqLen %d", len(positions), seqLen)
	}

	profile.Start("Block/AttnNorm")
	if err := compute.HybridRMSNorm(ctx, hidden, layer.AttnNorm, eps); err != nil {
		profile.End("Block/AttnNorm")
		return nil, fmt.Errorf("attn norm: %w", err)
	}
	profile.End("Block/AttnNorm")

	qOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	kOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	vOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}

	profile.Start("Block/QProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wq, qOut); err != nil {
		profile.End("Block/QProj")
		return nil, fmt.Errorf("q_proj: %w", err)
	}
	profile.End("Block/QProj")

	profile.Start("Block/KProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wk, kOut); err != nil {
		profile.End("Block/KProj")
		return nil, fmt.Errorf("k_proj: %w", err)
	}
	profile.End("Block/KProj")

	profile.Start("Block/VProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wv, vOut); err != nil {
		profile.End("Block/VProj")
		return nil, fmt.Errorf("v_proj: %w", err)
	}
	profile.End("Block/VProj")

	if layer.QNorm != nil {
		profile.Start("Block/QNorm")
		if err := compute.HybridRMSNorm(ctx, qOut, layer.QNorm, eps); err != nil {
			profile.End("Block/QNorm")
			return nil, fmt.Errorf("q_norm: %w", err)
		}
		profile.End("Block/QNorm")
	}
	if layer.KNorm != nil {
		profile.Start("Block/KNorm")
		if err := compute.HybridRMSNorm(ctx, kOut, layer.KNorm, eps); err != nil {
			profile.End("Block/KNorm")
			return nil, fmt.Errorf("k_norm: %w", err)
		}
		profile.End("Block/KNorm")
	}

	// Only skip standalone RoPE when we will actually use fused RoPE inside CUDA paged attention.
	canPagedCUDA := false
	if ctx != nil && ctx.IsGPU() && cuda.Available() {
		gpu := ctx.Placement().GPU
		canPagedCUDA = true
		for i := 0; i < seqLen; i++ {
			if caches[i] == nil {
				canPagedCUDA = false
				break
			}
			pc, ok := caches[i].(*kvcache.PagedKVCache)
			if !ok || pc == nil || !pc.IsOnGPU() || pc.LayerDevice(layer.Idx).GPU != gpu {
				canPagedCUDA = false
				break
			}
		}
	}
	useFusedRoPE := canPagedCUDA && cfg.HeadDim <= 256 && (cfg.HeadDim&1) == 0 && cfg.RopeTheta != 0
	if !useFusedRoPE {
		profile.Start("Block/RoPE_Q")
		if err := compute.HybridRoPE(ctx, qOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_Q")
			return nil, fmt.Errorf("rope q: %w", err)
		}
		profile.End("Block/RoPE_Q")

		profile.Start("Block/RoPE_K")
		if err := compute.HybridRoPE(ctx, kOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_K")
			return nil, fmt.Errorf("rope k: %w", err)
		}
		profile.End("Block/RoPE_K")
	}

	attnOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	scale := float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	profile.Start("Block/Attention")
	didPaged := false
	if canPagedCUDA {
		gpu := ctx.Placement().GPU
		gpuQ, err := qOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("q to cuda: %w", err)
		}
		gpuK, err := kOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("k to cuda: %w", err)
		}
		gpuV, err := vOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("v to cuda: %w", err)
		}
		gpuOut, err := attnOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("attn out to cuda: %w", err)
		}
		flatKPtrs := make([]uintptr, 0)
		flatVPtrs := make([]uintptr, 0)
		blockOffsets := make([]int32, seqLen)
		kvLens := make([]int32, seqLen)
		queryPos := make([]int32, seqLen)
		maxKvLen := 0
		var kvType tensor.DType
		blockSize := 0
		kvDim := cfg.NumKVHeads * cfg.HeadDim
		for i := 0; i < seqLen; i++ {
			pc := caches[i].(*kvcache.PagedKVCache)
			startPos := pc.SeqLen()
			offBytes := uintptr(i * kvDim * 4)
			kView, err := gpuK.ViewAt(tensor.Shape{1, kvDim}, offBytes)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("k view: %w", err)
			}
			vView, err := gpuV.ViewAt(tensor.Shape{1, kvDim}, offBytes)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("v view: %w", err)
			}
			if _, _, err := pc.Append(layer.Idx, kView, vView); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			kvLen := startPos + 1
			if kvLen > maxKvLen {
				maxKvLen = kvLen
			}
			bs := pc.BlockSize()
			if blockSize == 0 {
				blockSize = bs
			} else if blockSize != bs {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("mixed block sizes in batch: %d vs %d", blockSize, bs)
			}
			nBlocks := (kvLen + bs - 1) / bs
			kPtrs, vPtrs, _, curType, err := pc.LayerBlockPtrTables(layer.Idx, nBlocks)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("kv ptr tables: %w", err)
			}
			if kvType == 0 {
				kvType = curType
			} else if kvType != curType {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("mixed KV dtypes in batch: %v vs %v", kvType, curType)
			}
			blockOffsets[i] = int32(len(flatKPtrs))
			flatKPtrs = append(flatKPtrs, kPtrs...)
			flatVPtrs = append(flatVPtrs, vPtrs...)
			kvLens[i] = int32(kvLen)
			queryPos[i] = int32(startPos)
		}
		if blockSize <= 0 {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("batched attention: invalid blockSize %d", blockSize)
		}
		if len(flatKPtrs) == 0 || len(flatVPtrs) == 0 {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("batched attention: empty KV ptr tables")
		}
		var (
			kFlatDev  unsafe.Pointer
			vFlatDev  unsafe.Pointer
			offDev    unsafe.Pointer
			kvDev     unsafe.Pointer
			qposDev   unsafe.Pointer
			freeAfter bool
		)
		if ctx != nil && ctx.Scratch != nil {
			kFlatDev, err = ctx.Scratch.GetUintptrSlice(len(flatKPtrs))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch K ptr table: %w", err)
			}
			vFlatDev, err = ctx.Scratch.GetUintptrSlice(len(flatVPtrs))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch V ptr table: %w", err)
			}
			offDev, err = ctx.Scratch.GetInt32Slice(len(blockOffsets))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch offsets: %w", err)
			}
			kvDev, err = ctx.Scratch.GetInt32Slice(len(kvLens))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch kv lens: %w", err)
			}
			qposDev, err = ctx.Scratch.GetInt32Slice(len(queryPos))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch query pos: %w", err)
			}
			if err := cuda.MemcpyH2D(kFlatDev, unsafe.Pointer(&flatKPtrs[0]), uintptr(len(flatKPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy K ptr table: %w", err)
			}
			if err := cuda.MemcpyH2D(vFlatDev, unsafe.Pointer(&flatVPtrs[0]), uintptr(len(flatVPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy V ptr table: %w", err)
			}
			if err := cuda.MemcpyH2D(offDev, unsafe.Pointer(&blockOffsets[0]), uintptr(len(blockOffsets))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy offsets: %w", err)
			}
			if err := cuda.MemcpyH2D(kvDev, unsafe.Pointer(&kvLens[0]), uintptr(len(kvLens))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy kv lens: %w", err)
			}
			if err := cuda.MemcpyH2D(qposDev, unsafe.Pointer(&queryPos[0]), uintptr(len(queryPos))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy query pos: %w", err)
			}
		} else {
			freeAfter = true
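			// No scratch arena: make one-off device allocations for the pointer tables and index
			// arrays and free them when this call returns (freeAfter drives the defers below).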
			kFlatDev, err = cuda.AllocAndCopyPtrTable(flatKPtrs, gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc K ptr table: %w", err)
			}
			vFlatDev, err = cuda.AllocAndCopyPtrTable(flatVPtrs, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc V ptr table: %w", err)
			}
			offDev, err = cuda.AllocAndCopyInt32(blockOffsets, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc offsets: %w", err)
			}
			kvDev, err = cuda.AllocAndCopyInt32(kvLens, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				cuda.FreeDevicePtr(offDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc kv lens: %w", err)
			}
			qposDev, err = cuda.AllocAndCopyInt32(queryPos, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				cuda.FreeDevicePtr(offDev)
				cuda.FreeDevicePtr(kvDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc query pos: %w", err)
			}
		}
		if freeAfter {
			defer cuda.FreeDevicePtr(kFlatDev)
			defer cuda.FreeDevicePtr(vFlatDev)
			defer cuda.FreeDevicePtr(offDev)
			defer cuda.FreeDevicePtr(kvDev)
			defer cuda.FreeDevicePtr(qposDev)
		}
		if kvType == tensor.Float16 {
			if useFusedRoPE {
				err = cuda.PagedAttentionBatchRoPEF32F16KV(
					gpuQ.Data().(unsafe.Pointer), kFlatDev, vFlatDev, offDev, kvDev, qposDev, gpuOut.Data().(unsafe.Pointer),
					seqLen, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize, scale, maxKvLen, cfg.RopeTheta, gpu,
				)
			} else {
				err = cuda.PagedAttentionBatchF32F16KV(
					gpuQ.Data().(unsafe.Pointer), kFlatDev, vFlatDev, offDev, kvDev, qposDev, gpuOut.Data().(unsafe.Pointer),
					seqLen, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize, scale, maxKvLen, gpu,
				)
			}
		} else {
			err = cuda.PagedAttentionBatch(
				gpuQ.Data().(unsafe.Pointer), kFlatDev, vFlatDev, offDev, kvDev, qposDev, gpuOut.Data().(unsafe.Pointer),
				seqLen, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
				blockSize, scale, maxKvLen, gpu,
			)
		}
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("paged attention batch: %w", err)
		}
		didPaged = true
	}
	if !didPaged {
		// CPU KV cache path (supports running the layer on GPU while keeping KV on CPU).
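		// Each token in the batch has its own cache, so K/V rows are appended per token and
		// attention runs against that token's block views, using the packed-views fast path
		// when every cache can provide it.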
		qCPU, err := qOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("q to cpu: %w", err)
		}
		kCPU, err := kOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("k to cpu: %w", err)
		}
		vCPU, err := vOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("v to cpu: %w", err)
		}
		outCPU, err := attnOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("attn out to cpu: %w", err)
		}
		qStride := cfg.NumHeads * cfg.HeadDim
		kvDim := cfg.NumKVHeads * cfg.HeadDim
		kAll := kCPU.DataFloat32()
		vAll := vCPU.DataFloat32()
		outAll := outCPU.DataFloat32()
		queryPos := make([]int, seqLen)
		packedViews := make([][]kvcache.PackedView, seqLen)
		viewsByToken := make([][]kvcache.View, seqLen)
		allPacked := true
		for i := 0; i < seqLen; i++ {
			if caches[i] == nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("batched attention requires non-nil KV cache for token %d", i)
			}
			layerDev := caches[i].LayerDevice(layer.Idx).Normalize()
			if layerDev.Type == tensor.CUDA {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("CPU batched attention requires CPU KV cache for token %d (got %v)", i, layerDev)
			}
			queryPos[i] = caches[i].SeqLen()
			kRow := cpu.NewTensor(tensor.Shape{1, kvDim}, kAll[i*kvDim:(i+1)*kvDim])
			vRow := cpu.NewTensor(tensor.Shape{1, kvDim}, vAll[i*kvDim:(i+1)*kvDim])
			views, _, err := caches[i].Append(layer.Idx, kRow, vRow)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			viewsByToken[i] = views
			if pv, ok := caches[i].(kvcache.PackedViewsProvider); ok {
				p := pv.ViewsPacked(layer.Idx)
				if len(p) > 0 {
					packedViews[i] = p
				} else {
					allPacked = false
				}
			} else {
				allPacked = false
			}
		}
		if allPacked {
			if err := nn.CausalAttentionPackedBlocksBatch(qCPU, packedViews, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attention packed batch: %w", err)
			}
		} else {
			for i := 0; i < seqLen; i++ {
				qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qCPU.DataFloat32()[i*qStride:(i+1)*qStride])
				outRow := cpu.NewTensor(tensor.Shape{1, qStride}, outAll[i*qStride:(i+1)*qStride])
				if len(packedViews[i]) != 0 {
					if err := nn.CausalAttentionPackedBlocks(qRow, packedViews[i], outRow, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos[i]); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention packed: %w", err)
					}
					continue
				}
				if err := nn.CausalAttentionBlocks(qRow, viewsByToken[i], outRow, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos[i]); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("attention blocks: %w", err)
				}
			}
		}
		// Restore attention output to GPU so the output projection stays on GPU.
		if ctx != nil && ctx.IsGPU() {
			if _, err := attnOut.EnsureOn(ctx.Placement()); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("move attn out to gpu: %w", err)
			}
		}
		didPaged = true
	}
	profile.End("Block/Attention")

	attnProj, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/OProj")
	if err := compute.HybridLinear(ctx, attnOut, layer.Wo, attnProj); err != nil {
		profile.End("Block/OProj")
		return nil, fmt.Errorf("o_proj: %w", err)
	}
	profile.End("Block/OProj")

	if err := compute.HybridAdd(ctx, residual, attnProj); err != nil {
		return nil, fmt.Errorf("residual 1: %w", err)
	}
	return residual, nil
}

func runHybridMLPBlock(ctx *compute.Context, alloc hybridAllocFn, postAttn *compute.Activation, layer *HybridDecoderLayerWeights, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := postAttn.Shape()[0]
	residual2, err := cloneActivation(ctx, alloc, postAttn, false)
	if err != nil {
		return nil, err
	}

	profile.Start("Block/MlpNorm")
	if err := compute.HybridRMSNorm(ctx, postAttn, layer.MlpNorm, eps); err != nil {
		profile.End("Block/MlpNorm")
		return nil, fmt.Errorf("mlp norm: %w", err)
	}
	profile.End("Block/MlpNorm")

	gate, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	up, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}

	profile.Start("Block/GateProj")
	if err := compute.HybridLinear(ctx, postAttn, layer.WGate, gate); err != nil {
		profile.End("Block/GateProj")
		return nil, fmt.Errorf("gate proj: %w", err)
	}
	profile.End("Block/GateProj")

	profile.Start("Block/UpProj")
	if err := compute.HybridLinear(ctx, postAttn, layer.WUp, up); err != nil {
		profile.End("Block/UpProj")
		return nil, fmt.Errorf("up proj: %w", err)
	}
	profile.End("Block/UpProj")

	profile.Start("Block/SwiGLU")
	act, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		profile.End("Block/SwiGLU")
		return nil, err
	}
	if err := compute.HybridSwiGLU(ctx, gate, up, act); err != nil {
		profile.End("Block/SwiGLU")
		return nil, err
	}
	profile.End("Block/SwiGLU")

	mlpOut, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/DownProj")
	if err := compute.HybridLinear(ctx, act, layer.WDown, mlpOut); err != nil {
		profile.End("Block/DownProj")
		return nil, fmt.Errorf("down proj: %w", err)
	}
	profile.End("Block/DownProj")

	if err := compute.HybridAdd(ctx, residual2, mlpOut); err != nil {
		return nil, fmt.Errorf("residual 2: %w", err)
	}
	return residual2, nil
}

func gatherKV(ctx *compute.Context, cache kvcache.KVCacheInterface, layerIdx int, kOut, vOut *compute.Activation, kvDim int) (*compute.Activation, *compute.Activation, int, error) {
	seqLen := kOut.Shape()[0]
	startPos := 0
	if cache == nil {
		return kOut, vOut, startPos, nil
	}
	startPos = cache.SeqLen()
	layerDev := cache.LayerDevice(layerIdx).Normalize()
	if layerDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		profile.Instant("KVCache/GPU_path", profile.EventOp, "")
		if _, _, err := cache.Append(layerIdx, kOut.Tensor(), vOut.Tensor()); err != nil {
			return nil, nil, 0, fmt.Errorf("cache append: %w", err)
		}
		kvLen := startPos + seqLen
		if kView, vView, ok, err := cache.ContiguousKV(layerIdx, kvLen, kvDim); err != nil {
			return nil, nil, 0, fmt.Errorf("contiguous kv: %w", err)
		} else if ok {
			if kView == nil || vView == nil {
				return kOut, vOut, startPos, nil
			}
			return compute.NewActivationFrom(kView), compute.NewActivationFrom(vView), startPos, nil
		}
		kAct, vAct, err := concatKVOnDevice(ctx, cache.Views(layerIdx), startPos+seqLen, kvDim)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("concat kv on device: %w", err)
		}
		// Append the current step's K/V to the end of the full buffer.
		// PagedKVCache views only include committed tokens, so we must manually append the new ones.
		kDst := kAct.Tensor().(*cuda.Tensor)
		vDst := vAct.Tensor().(*cuda.Tensor)
		kSrc := kOut.Tensor().(*cuda.Tensor)
		vSrc := vOut.Tensor().(*cuda.Tensor)
		dstOffset := startPos * kvDim
		copyLen := seqLen * kvDim
		if err := kDst.CopyPartialFromDevice(dstOffset, kSrc, 0, copyLen); err != nil {
			return nil, nil, 0, fmt.Errorf("copy current K: %w", err)
		}
		if err := vDst.CopyPartialFromDevice(dstOffset, vSrc, 0, copyLen); err != nil {
			return nil, nil, 0, fmt.Errorf("copy current V: %w", err)
		}
		return kAct, vAct, startPos, nil
	}

	profile.Instant("KVCache/CPU_path", profile.EventOp, fmt.Sprintf("layerDev=%v ctxGPU=%v", layerDev, ctx != nil && ctx.IsGPU()))
	kCPU, err := kOut.AsCPU()
	if err != nil {
		return nil, nil, 0, err
	}
	vCPU, err := vOut.AsCPU()
	if err != nil {
		return nil, nil, 0, err
	}
	views, _, err := cache.Append(layerIdx, kCPU, vCPU)
	if err != nil {
		return nil, nil, 0, fmt.Errorf("cache append: %w", err)
	}
	kvLen := startPos + seqLen
	fullKData := make([]float32, kvLen*kvDim)
	fullVData := make([]float32, kvLen*kvDim)
	for _, view := range views {
		kData, err := getViewData(view.K)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("get K data: %w", err)
		}
		vData, err := getViewData(view.V)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("get V data: %w", err)
		}
		for i := 0; i < view.Length; i++ {
			globalPos := view.Start + i
			copy(fullKData[globalPos*kvDim:(globalPos+1)*kvDim], kData[i*kvDim:(i+1)*kvDim])
			copy(fullVData[globalPos*kvDim:(globalPos+1)*kvDim], vData[i*kvDim:(i+1)*kvDim])
		}
	}
	fullK := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{kvLen, kvDim}, fullKData))
	fullV := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{kvLen, kvDim}, fullVData))
	if ctx != nil && ctx.IsGPU() {
		profile.Start("KVCache/FullK_H2D")
		if err := compute.EnsureOnDevice(fullK, ctx.Placement()); err != nil {
			profile.End("KVCache/FullK_H2D")
			return nil, nil, 0, err
		}
		profile.End("KVCache/FullK_H2D")
		profile.Start("KVCache/FullV_H2D")
		if err := compute.EnsureOnDevice(fullV, ctx.Placement()); err != nil {
			profile.End("KVCache/FullV_H2D")
			return nil, nil, 0, err
		}
		profile.End("KVCache/FullV_H2D")
	}
	return fullK, fullV, startPos, nil
}

func getViewData(t tensor.Tensor) ([]float32, error) {
	switch tt := t.(type) {
	case *cpu.Tensor:
		return tt.DataFloat32(), nil
	default:
		if copier, ok := t.(interface{ CopyToHost([]float32) error }); ok {
			data := make([]float32, t.Shape().NumElements())
			if err := copier.CopyToHost(data); err != nil {
				return nil, err
			}
			return data, nil
		}
		return nil, fmt.Errorf("unsupported tensor type: %T", t)
	}
}

func concatKVOnDevice(ctx *compute.Context, views []kvcache.View, kvLen, kvDim int) (*compute.Activation, *compute.Activation, error) {
	gpu := ctx.Placement().GPU
	fullKAct, err := func() (*compute.Activation, error) {
		if ctx != nil && ctx.Scratch != nil {
			if a, err := ctx.Scratch.GetTensor(tensor.Shape{kvLen, kvDim}, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(tensor.Shape{kvLen, kvDim}, ctx.Placement())
	}()
	if err != nil {
		return nil, nil, err
	}
	fullVAct, err := func() (*compute.Activation, error) {
		if ctx != nil && ctx.Scratch != nil {
			if a, err := ctx.Scratch.GetTensor(tensor.Shape{kvLen, kvDim}, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(tensor.Shape{kvLen, kvDim}, ctx.Placement())
	}()
	if err != nil {
		return nil, nil, err
	}
	fullKGPU, err := fullKAct.AsCUDA(gpu)
	if err != nil {
		return nil, nil, err
	}
	fullVGPU, err := fullVAct.AsCUDA(gpu)
	if err != nil {
		return nil, nil, err
	}
	for _, view := range views {
		kSrc, ok := view.K.(*cuda.Tensor)
		if !ok {
			return nil, nil, fmt.Errorf("expected CUDA tensor for K view, got %T", view.K)
		}
		vSrc, ok := view.V.(*cuda.Tensor)
		if !ok {
			return nil, nil, fmt.Errorf("expected CUDA tensor for V view, got %T", view.V)
		}
		dstStart := view.Start * kvDim
		length := view.Length * kvDim
		if err := fullKGPU.CopyPartialFromDevice(dstStart, kSrc, 0, length); err != nil {
			return nil, nil, fmt.Errorf("copy K view: %w", err)
		}
		if err := fullVGPU.CopyPartialFromDevice(dstStart, vSrc, 0, length); err != nil {
			return nil, nil, fmt.Errorf("copy V view: %w", err)
		}
	}
	return fullKAct, fullVAct, nil
}
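
// Usage sketch (illustrative only; layers, hidden, positions, and cache are hypothetical
// caller-side names, not part of this package): a forward pass typically calls
// HybridDecoderBlock once per layer, reusing the same hidden activation:
//
//	cfg := HybridDecoderConfig{HiddenSize: 4096, NumHeads: 32, NumKVHeads: 8, HeadDim: 128, RopeTheta: 1e6}
//	for i := range layers {
//		if err := HybridDecoderBlock(ctx, hidden, &layers[i], positions, cache, cfg, 1e-6); err != nil {
//			return err
//		}
//	}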