package arch

import (
	"fmt"
	"math"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/compute"
	"makarna/pkg/kvcache"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)
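
// HybridDecoderLayerWeights holds the per-layer weights for a decoder block:
// the attention and MLP norms, the Q/K/V and output projections, optional
// Q/K norms, and the gated MLP (gate/up/down) projections.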
type HybridDecoderLayerWeights struct {
	Idx      int
	AttnNorm tensor.Tensor
	Wq       tensor.Tensor
	Wk       tensor.Tensor
	Wv       tensor.Tensor
	Wo       tensor.Tensor
	QNorm    tensor.Tensor
	KNorm    tensor.Tensor
	MlpNorm  tensor.Tensor
	WGate    tensor.Tensor
	WUp      tensor.Tensor
	WDown    tensor.Tensor
}
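
// HybridDecoderConfig describes the attention and MLP geometry of a decoder
// layer. NumKVHeads and Intermediate may be left zero; see
// normalizeHybridDecoderConfig for the defaults applied.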
type HybridDecoderConfig struct {
	HiddenSize   int
	NumHeads     int
	NumKVHeads   int
	Intermediate int
	HeadDim      int
	RopeTheta    float32
}
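
// HybridDecoderBlock runs one decoder layer (attention and MLP, each with a
// residual connection) over hidden for a single request, appending this
// step's K/V to cache and writing the result back into hidden.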
func HybridDecoderBlock(ctx *compute.Context, hidden *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, cache kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) error {
	cfg = normalizeHybridDecoderConfig(cfg)
	if err := ensureActivationOnContextDevice(ctx, hidden); err != nil {
		return err
	}
	alloc := makeHybridAllocator(ctx)
	residual, err := cloneActivation(ctx, alloc, hidden, false)
	if err != nil {
		return err
	}
	postAttn, err := runHybridAttentionBlock(ctx, alloc, hidden, residual, layer, positions, cache, cfg, eps)
	if err != nil {
		return err
	}
	out, err := runHybridMLPBlock(ctx, alloc, postAttn, layer, cfg, eps)
	if err != nil {
		return err
	}
	hidden.ReplaceWith(out.Tensor())
	return nil
}
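
// HybridDecoderBlockBatch is the batched variant of HybridDecoderBlock: each
// row of hidden belongs to an independent request with its own KV cache, so
// len(caches) must match the batch dimension of hidden.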
func HybridDecoderBlockBatch(ctx *compute.Context, hidden *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, caches []kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) error {
	cfg = normalizeHybridDecoderConfig(cfg)
	if err := ensureActivationOnContextDevice(ctx, hidden); err != nil {
		return err
	}
	if len(caches) != hidden.Shape()[0] {
		return fmt.Errorf("caches len %d != hidden batch %d", len(caches), hidden.Shape()[0])
	}
	alloc := makeHybridAllocator(ctx)
	residual, err := cloneActivation(ctx, alloc, hidden, false)
	if err != nil {
		return err
	}
	postAttn, err := runHybridAttentionBlockBatch(ctx, alloc, hidden, residual, layer, positions, caches, cfg, eps)
	if err != nil {
		return err
	}
	out, err := runHybridMLPBlock(ctx, alloc, postAttn, layer, cfg, eps)
	if err != nil {
		return err
	}
	hidden.ReplaceWith(out.Tensor())
	return nil
}
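
// normalizeHybridDecoderConfig fills in defaults: NumKVHeads falls back to
// NumHeads and Intermediate to 4*HiddenSize.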
func normalizeHybridDecoderConfig(cfg HybridDecoderConfig) HybridDecoderConfig {
	if cfg.NumKVHeads == 0 {
		cfg.NumKVHeads = cfg.NumHeads
	}
	if cfg.Intermediate == 0 {
		cfg.Intermediate = cfg.HiddenSize * 4
	}
	return cfg
}
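
// ensureActivationOnContextDevice moves the activation onto the context's
// target device when a context is present.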
func ensureActivationOnContextDevice(ctx *compute.Context, a *compute.Activation) error {
	if ctx == nil {
		return nil
	}
	if _, err := a.EnsureOn(ctx.Placement()); err != nil {
		return fmt.Errorf("move hidden to target device: %w", err)
	}
	return nil
}
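
// hybridAllocFn allocates an activation of the given shape on the given
// placement; allowScratch permits serving the request from the context's
// GPU scratch pool.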
type hybridAllocFn func(shape tensor.Shape, placement tensor.DevicePlacement, allowScratch bool) (*compute.Activation, error)
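
// makeHybridAllocator returns an allocator that prefers the context's CUDA
// scratch pool when the requested placement matches the scratch GPU, and
// falls back to a regular activation allocation otherwise.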
func makeHybridAllocator(ctx *compute.Context) hybridAllocFn {
	return func(shape tensor.Shape, placement tensor.DevicePlacement, allowScratch bool) (*compute.Activation, error) {
		placement = placement.Normalize()
		if allowScratch && ctx != nil && ctx.Scratch != nil && placement.Type == tensor.CUDA && ctx.Scratch.GPU() == placement.GPU {
			if a, err := ctx.Scratch.GetTensor(shape, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(shape, placement)
	}
}
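
// cloneActivation allocates an activation with src's shape and placement and
// copies src into it; used to preserve the residual before in-place updates.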
func cloneActivation(ctx *compute.Context, alloc hybridAllocFn, src *compute.Activation, allowScratch bool) (*compute.Activation, error) {
	dst, err := alloc(src.Shape(), src.Placement(), allowScratch)
	if err != nil {
		return nil, fmt.Errorf("alloc residual: %w", err)
	}
	if err := compute.HybridCopy(ctx, dst, src); err != nil {
		return nil, fmt.Errorf("copy residual: %w", err)
	}
	return dst, nil
}
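
// runHybridAttentionBlock computes the attention half of the layer for a
// single request: attention RMSNorm, Q/K/V projections, optional Q/K norms,
// RoPE (standalone or fused into the CUDA paged-attention kernel), attention
// over the KV cache, the output projection, and the residual add. It returns
// the post-attention hidden state (residual + attention output).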
func runHybridAttentionBlock(ctx *compute.Context, alloc hybridAllocFn, hidden *compute.Activation, residual *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, cache kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := hidden.Shape()[0]
	profile.Start("Block/AttnNorm")
	if err := compute.HybridRMSNorm(ctx, hidden, layer.AttnNorm, eps); err != nil {
		profile.End("Block/AttnNorm")
		return nil, fmt.Errorf("attn norm: %w", err)
	}
	profile.End("Block/AttnNorm")
	qOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	kOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	vOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/QProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wq, qOut); err != nil {
		profile.End("Block/QProj")
		return nil, fmt.Errorf("q_proj: %w", err)
	}
	profile.End("Block/QProj")
	profile.Start("Block/KProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wk, kOut); err != nil {
		profile.End("Block/KProj")
		return nil, fmt.Errorf("k_proj: %w", err)
	}
	profile.End("Block/KProj")
	profile.Start("Block/VProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wv, vOut); err != nil {
		profile.End("Block/VProj")
		return nil, fmt.Errorf("v_proj: %w", err)
	}
	profile.End("Block/VProj")
	if layer.QNorm != nil {
		profile.Start("Block/QNorm")
		if err := compute.HybridRMSNorm(ctx, qOut, layer.QNorm, eps); err != nil {
			profile.End("Block/QNorm")
			return nil, fmt.Errorf("q_norm: %w", err)
		}
		profile.End("Block/QNorm")
	}
	if layer.KNorm != nil {
		profile.Start("Block/KNorm")
		if err := compute.HybridRMSNorm(ctx, kOut, layer.KNorm, eps); err != nil {
			profile.End("Block/KNorm")
			return nil, fmt.Errorf("k_norm: %w", err)
		}
		profile.End("Block/KNorm")
	}
	// Decide whether we can fuse RoPE inside the paged attention kernel.
	// This skips the standalone RoPE kernel launches and avoids an extra read/modify/write of Q/K.
	layerCacheDev := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	var pc *kvcache.PagedKVCache
	useFusedRoPE := false
	if cache != nil {
		layerCacheDev = cache.LayerDevice(layer.Idx).Normalize()
		if p, ok := cache.(*kvcache.PagedKVCache); ok {
			pc = p
		}
	}
	if pc != nil && layerCacheDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		gpu := ctx.Placement().GPU
		// Fused kernel supports headDim<=256 and even headDim.
		if pc.LayerDevice(layer.Idx).GPU == gpu && cfg.HeadDim <= 256 && (cfg.HeadDim&1) == 0 && cfg.RopeTheta != 0 {
			useFusedRoPE = true
		}
	}
	if !useFusedRoPE {
		profile.Start("Block/RoPE_Q")
		if err := compute.HybridRoPE(ctx, qOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_Q")
			return nil, fmt.Errorf("rope q: %w", err)
		}
		profile.End("Block/RoPE_Q")
		profile.Start("Block/RoPE_K")
		if err := compute.HybridRoPE(ctx, kOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_K")
			return nil, fmt.Errorf("rope k: %w", err)
		}
		profile.End("Block/RoPE_K")
	}
	attnOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	scale := float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	profile.Start("Block/Attention")
	didPaged := false
	if cache != nil && layerCacheDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		if pc != nil {
			startPos := cache.SeqLen()
			gpu := ctx.Placement().GPU
			if pc.LayerDevice(layer.Idx).GPU != gpu {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("paged attention requires cache layer %d on gpu %d (got gpu %d)", layer.Idx, gpu, pc.LayerDevice(layer.Idx).GPU)
			}
			if _, _, err := cache.Append(layer.Idx, kOut.Tensor(), vOut.Tensor()); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			kvLen := startPos + seqLen
			bs := pc.BlockSize()
			numBlocks := (kvLen + bs - 1) / bs
			var (
				kDev      unsafe.Pointer
				vDev      unsafe.Pointer
				freeAfter bool
				kvType    tensor.DType
				blockSize int
			)
			// Prefer persistent per-layer device pointer tables (rebuilt only when numBlocks grows).
			kDev, vDev, blockSize, kvType, err = pc.LayerDevicePtrTables(layer.Idx, numBlocks)
			if err != nil {
				// Fall back to per-call scratch allocation/copy if needed.
				kPtrs, vPtrs, blockSize2, kvType2, err2 := pc.LayerBlockPtrTables(layer.Idx, numBlocks)
				if err2 != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("kv ptr tables: %w", err2)
				}
				blockSize = blockSize2
				kvType = kvType2
				if ctx != nil && ctx.Scratch != nil {
					kDev, err = ctx.Scratch.GetUintptrSlice(len(kPtrs))
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("scratch K ptr table: %w", err)
					}
					vDev, err = ctx.Scratch.GetUintptrSlice(len(vPtrs))
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("scratch V ptr table: %w", err)
					}
					if err := cuda.MemcpyH2D(kDev, unsafe.Pointer(&kPtrs[0]), uintptr(len(kPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("memcpy K ptr table: %w", err)
					}
					if err := cuda.MemcpyH2D(vDev, unsafe.Pointer(&vPtrs[0]), uintptr(len(vPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("memcpy V ptr table: %w", err)
					}
				} else {
					freeAfter = true
					kDev, err = cuda.AllocAndCopyPtrTable(kPtrs, gpu)
					if err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("alloc K ptr table: %w", err)
					}
					vDev, err = cuda.AllocAndCopyPtrTable(vPtrs, gpu)
					if err != nil {
						cuda.FreeDevicePtr(kDev)
						profile.End("Block/Attention")
						return nil, fmt.Errorf("alloc V ptr table: %w", err)
					}
				}
			}
			if freeAfter {
				defer cuda.FreeDevicePtr(kDev)
				defer cuda.FreeDevicePtr(vDev)
			}
			gpuQ, err := qOut.AsCUDA(gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("q to cuda: %w", err)
			}
			gpuOut, err := attnOut.AsCUDA(gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attn out to cuda: %w", err)
			}
			// Single-request fast path: use the non-batch paged attention kernel for any seqLen.
			// This avoids per-call device allocations for blockOffsets/kvLens/queryPos.
			if kvType == tensor.Float16 {
				if useFusedRoPE {
					err = cuda.PagedAttentionRoPEF32F16KV(
						gpuQ.Data().(unsafe.Pointer),
						kDev,
						vDev,
						gpuOut.Data().(unsafe.Pointer),
						seqLen, kvLen,
						cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
						blockSize,
						scale, startPos,
						cfg.RopeTheta,
						gpu,
					)
				} else {
					err = cuda.PagedAttentionF32F16KV(
						gpuQ.Data().(unsafe.Pointer),
						kDev,
						vDev,
						gpuOut.Data().(unsafe.Pointer),
						seqLen, kvLen,
						cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
						blockSize,
						scale, startPos,
						gpu,
					)
				}
			} else {
				err = cuda.PagedAttention(
					gpuQ.Data().(unsafe.Pointer),
					kDev,
					vDev,
					gpuOut.Data().(unsafe.Pointer),
					seqLen, kvLen,
					cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize,
					scale, startPos,
					gpu,
				)
			}
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("paged attention: %w", err)
			}
			didPaged = true
		}
	}
	if !didPaged {
		// CPU path: do NOT materialize full K/V history. That is O(kvLen*kvDim) per step
		// and causes severe context-dependent slowdown. Instead, run attention directly
		// over the KV cache block views.
		if cache != nil && layerCacheDev.Type != tensor.CUDA {
			startPos := cache.SeqLen()
			qCPU, err := qOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("q to cpu: %w", err)
			}
			kCPU, err := kOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("k to cpu: %w", err)
			}
			vCPU, err := vOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("v to cpu: %w", err)
			}
			outCPU, err := attnOut.AsCPU()
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attn out to cpu: %w", err)
			}
			views, _, err := cache.Append(layer.Idx, kCPU, vCPU)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
				pviews := pv.ViewsPacked(layer.Idx)
				if len(pviews) != 0 {
					if err := nn.CausalAttentionPackedBlocks(qCPU, pviews, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention packed: %w", err)
					}
				} else {
					if err := nn.CausalAttentionBlocks(qCPU, views, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention blocks: %w", err)
					}
				}
			} else {
				if err := nn.CausalAttentionBlocks(qCPU, views, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, startPos); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("attention blocks: %w", err)
				}
			}
			// If the layer itself is on GPU, restore attention output to GPU so the output
			// projection runs on GPU even when KV cache is on CPU.
			if ctx != nil && ctx.IsGPU() {
				if _, err := attnOut.EnsureOn(ctx.Placement()); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("move attn out to gpu: %w", err)
				}
			}
		} else {
			fullK, fullV, startPos, err := gatherKV(ctx, cache, layer.Idx, kOut, vOut, cfg.NumKVHeads*cfg.HeadDim)
			if err != nil {
				profile.End("Block/Attention")
				return nil, err
			}
			if err := compute.HybridAttention(ctx, qOut, fullK, fullV, attnOut, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, scale, startPos); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attention: %w", err)
			}
		}
	}
	profile.End("Block/Attention")
	attnProj, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/OProj")
	if err := compute.HybridLinear(ctx, attnOut, layer.Wo, attnProj); err != nil {
		profile.End("Block/OProj")
		return nil, fmt.Errorf("o_proj: %w", err)
	}
	profile.End("Block/OProj")
	if err := compute.HybridAdd(ctx, residual, attnProj); err != nil {
		return nil, fmt.Errorf("residual 1: %w", err)
	}
	return residual, nil
}
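
// runHybridAttentionBlockBatch is the batched counterpart of
// runHybridAttentionBlock for decode steps: each row of hidden is one new
// token from a separate request, so every cache is appended with exactly one
// K/V row and attended with its own position and KV length.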
func runHybridAttentionBlockBatch(ctx *compute.Context, alloc hybridAllocFn, hidden *compute.Activation, residual *compute.Activation, layer *HybridDecoderLayerWeights, positions []int, caches []kvcache.KVCacheInterface, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := hidden.Shape()[0]
	if len(caches) != seqLen {
		return nil, fmt.Errorf("caches len %d != seqLen %d", len(caches), seqLen)
	}
	if len(positions) != seqLen {
		return nil, fmt.Errorf("positions len %d != seqLen %d", len(positions), seqLen)
	}
	profile.Start("Block/AttnNorm")
	if err := compute.HybridRMSNorm(ctx, hidden, layer.AttnNorm, eps); err != nil {
		profile.End("Block/AttnNorm")
		return nil, fmt.Errorf("attn norm: %w", err)
	}
	profile.End("Block/AttnNorm")
	qOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	kOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	vOut, err := alloc(tensor.Shape{seqLen, cfg.NumKVHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/QProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wq, qOut); err != nil {
		profile.End("Block/QProj")
		return nil, fmt.Errorf("q_proj: %w", err)
	}
	profile.End("Block/QProj")
	profile.Start("Block/KProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wk, kOut); err != nil {
		profile.End("Block/KProj")
		return nil, fmt.Errorf("k_proj: %w", err)
	}
	profile.End("Block/KProj")
	profile.Start("Block/VProj")
	if err := compute.HybridLinear(ctx, hidden, layer.Wv, vOut); err != nil {
		profile.End("Block/VProj")
		return nil, fmt.Errorf("v_proj: %w", err)
	}
	profile.End("Block/VProj")
	if layer.QNorm != nil {
		profile.Start("Block/QNorm")
		if err := compute.HybridRMSNorm(ctx, qOut, layer.QNorm, eps); err != nil {
			profile.End("Block/QNorm")
			return nil, fmt.Errorf("q_norm: %w", err)
		}
		profile.End("Block/QNorm")
	}
	if layer.KNorm != nil {
		profile.Start("Block/KNorm")
		if err := compute.HybridRMSNorm(ctx, kOut, layer.KNorm, eps); err != nil {
			profile.End("Block/KNorm")
			return nil, fmt.Errorf("k_norm: %w", err)
		}
		profile.End("Block/KNorm")
	}
	// Only skip standalone RoPE when we will actually use fused RoPE inside CUDA paged attention.
	canPagedCUDA := false
	if ctx != nil && ctx.IsGPU() && cuda.Available() {
		gpu := ctx.Placement().GPU
		canPagedCUDA = true
		for i := 0; i < seqLen; i++ {
			if caches[i] == nil {
				canPagedCUDA = false
				break
			}
			pc, ok := caches[i].(*kvcache.PagedKVCache)
			if !ok || pc == nil || !pc.IsOnGPU() || pc.LayerDevice(layer.Idx).GPU != gpu {
				canPagedCUDA = false
				break
			}
		}
	}
	useFusedRoPE := canPagedCUDA && cfg.HeadDim <= 256 && (cfg.HeadDim&1) == 0 && cfg.RopeTheta != 0
	if !useFusedRoPE {
		profile.Start("Block/RoPE_Q")
		if err := compute.HybridRoPE(ctx, qOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_Q")
			return nil, fmt.Errorf("rope q: %w", err)
		}
		profile.End("Block/RoPE_Q")
		profile.Start("Block/RoPE_K")
		if err := compute.HybridRoPE(ctx, kOut, positions, cfg.HeadDim, cfg.RopeTheta); err != nil {
			profile.End("Block/RoPE_K")
			return nil, fmt.Errorf("rope k: %w", err)
		}
		profile.End("Block/RoPE_K")
	}
	attnOut, err := alloc(tensor.Shape{seqLen, cfg.NumHeads * cfg.HeadDim}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	scale := float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	profile.Start("Block/Attention")
	didPaged := false
	if canPagedCUDA {
		gpu := ctx.Placement().GPU
		gpuQ, err := qOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("q to cuda: %w", err)
		}
		gpuK, err := kOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("k to cuda: %w", err)
		}
		gpuV, err := vOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("v to cuda: %w", err)
		}
		gpuOut, err := attnOut.AsCUDA(gpu)
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("attn out to cuda: %w", err)
		}
		flatKPtrs := make([]uintptr, 0)
		flatVPtrs := make([]uintptr, 0)
		blockOffsets := make([]int32, seqLen)
		kvLens := make([]int32, seqLen)
		queryPos := make([]int32, seqLen)
		maxKvLen := 0
		var kvType tensor.DType
		blockSize := 0
		kvDim := cfg.NumKVHeads * cfg.HeadDim
		for i := 0; i < seqLen; i++ {
			pc := caches[i].(*kvcache.PagedKVCache)
			startPos := pc.SeqLen()
			offBytes := uintptr(i * kvDim * 4)
			kView, err := gpuK.ViewAt(tensor.Shape{1, kvDim}, offBytes)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("k view: %w", err)
			}
			vView, err := gpuV.ViewAt(tensor.Shape{1, kvDim}, offBytes)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("v view: %w", err)
			}
			if _, _, err := pc.Append(layer.Idx, kView, vView); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			kvLen := startPos + 1
			if kvLen > maxKvLen {
				maxKvLen = kvLen
			}
			bs := pc.BlockSize()
			if blockSize == 0 {
				blockSize = bs
			} else if blockSize != bs {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("mixed block sizes in batch: %d vs %d", blockSize, bs)
			}
			nBlocks := (kvLen + bs - 1) / bs
			kPtrs, vPtrs, _, curType, err := pc.LayerBlockPtrTables(layer.Idx, nBlocks)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("kv ptr tables: %w", err)
			}
			if kvType == 0 {
				kvType = curType
			} else if kvType != curType {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("mixed KV dtypes in batch: %v vs %v", kvType, curType)
			}
			blockOffsets[i] = int32(len(flatKPtrs))
			flatKPtrs = append(flatKPtrs, kPtrs...)
			flatVPtrs = append(flatVPtrs, vPtrs...)
			kvLens[i] = int32(kvLen)
			queryPos[i] = int32(startPos)
		}
		if blockSize <= 0 {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("batched attention: invalid blockSize %d", blockSize)
		}
		if len(flatKPtrs) == 0 || len(flatVPtrs) == 0 {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("batched attention: empty KV ptr tables")
		}
		var (
			kFlatDev  unsafe.Pointer
			vFlatDev  unsafe.Pointer
			offDev    unsafe.Pointer
			kvDev     unsafe.Pointer
			qposDev   unsafe.Pointer
			freeAfter bool
		)
		if ctx != nil && ctx.Scratch != nil {
			kFlatDev, err = ctx.Scratch.GetUintptrSlice(len(flatKPtrs))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch K ptr table: %w", err)
			}
			vFlatDev, err = ctx.Scratch.GetUintptrSlice(len(flatVPtrs))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch V ptr table: %w", err)
			}
			offDev, err = ctx.Scratch.GetInt32Slice(len(blockOffsets))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch offsets: %w", err)
			}
			kvDev, err = ctx.Scratch.GetInt32Slice(len(kvLens))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch kv lens: %w", err)
			}
			qposDev, err = ctx.Scratch.GetInt32Slice(len(queryPos))
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("scratch query pos: %w", err)
			}
			if err := cuda.MemcpyH2D(kFlatDev, unsafe.Pointer(&flatKPtrs[0]), uintptr(len(flatKPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy K ptr table: %w", err)
			}
			if err := cuda.MemcpyH2D(vFlatDev, unsafe.Pointer(&flatVPtrs[0]), uintptr(len(flatVPtrs))*unsafe.Sizeof(uintptr(0)), gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy V ptr table: %w", err)
			}
			if err := cuda.MemcpyH2D(offDev, unsafe.Pointer(&blockOffsets[0]), uintptr(len(blockOffsets))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy offsets: %w", err)
			}
			if err := cuda.MemcpyH2D(kvDev, unsafe.Pointer(&kvLens[0]), uintptr(len(kvLens))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy kv lens: %w", err)
			}
			if err := cuda.MemcpyH2D(qposDev, unsafe.Pointer(&queryPos[0]), uintptr(len(queryPos))*4, gpu); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("memcpy query pos: %w", err)
			}
		} else {
			freeAfter = true
			kFlatDev, err = cuda.AllocAndCopyPtrTable(flatKPtrs, gpu)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc K ptr table: %w", err)
			}
			vFlatDev, err = cuda.AllocAndCopyPtrTable(flatVPtrs, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc V ptr table: %w", err)
			}
			offDev, err = cuda.AllocAndCopyInt32(blockOffsets, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc offsets: %w", err)
			}
			kvDev, err = cuda.AllocAndCopyInt32(kvLens, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				cuda.FreeDevicePtr(offDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc kv lens: %w", err)
			}
			qposDev, err = cuda.AllocAndCopyInt32(queryPos, gpu)
			if err != nil {
				cuda.FreeDevicePtr(kFlatDev)
				cuda.FreeDevicePtr(vFlatDev)
				cuda.FreeDevicePtr(offDev)
				cuda.FreeDevicePtr(kvDev)
				profile.End("Block/Attention")
				return nil, fmt.Errorf("alloc query pos: %w", err)
			}
		}
		if freeAfter {
			defer cuda.FreeDevicePtr(kFlatDev)
			defer cuda.FreeDevicePtr(vFlatDev)
			defer cuda.FreeDevicePtr(offDev)
			defer cuda.FreeDevicePtr(kvDev)
			defer cuda.FreeDevicePtr(qposDev)
		}
		if kvType == tensor.Float16 {
			if useFusedRoPE {
				err = cuda.PagedAttentionBatchRoPEF32F16KV(
					gpuQ.Data().(unsafe.Pointer),
					kFlatDev,
					vFlatDev,
					offDev,
					kvDev,
					qposDev,
					gpuOut.Data().(unsafe.Pointer),
					seqLen,
					cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize,
					scale,
					maxKvLen,
					cfg.RopeTheta,
					gpu,
				)
			} else {
				err = cuda.PagedAttentionBatchF32F16KV(
					gpuQ.Data().(unsafe.Pointer),
					kFlatDev,
					vFlatDev,
					offDev,
					kvDev,
					qposDev,
					gpuOut.Data().(unsafe.Pointer),
					seqLen,
					cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
					blockSize,
					scale,
					maxKvLen,
					gpu,
				)
			}
		} else {
			err = cuda.PagedAttentionBatch(
				gpuQ.Data().(unsafe.Pointer),
				kFlatDev,
				vFlatDev,
				offDev,
				kvDev,
				qposDev,
				gpuOut.Data().(unsafe.Pointer),
				seqLen,
				cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim,
				blockSize,
				scale,
				maxKvLen,
				gpu,
			)
		}
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("paged attention batch: %w", err)
		}
		didPaged = true
	}
	if !didPaged {
		// CPU KV cache path (supports running the layer on GPU while keeping KV on CPU).
		qCPU, err := qOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("q to cpu: %w", err)
		}
		kCPU, err := kOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("k to cpu: %w", err)
		}
		vCPU, err := vOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("v to cpu: %w", err)
		}
		outCPU, err := attnOut.AsCPU()
		if err != nil {
			profile.End("Block/Attention")
			return nil, fmt.Errorf("attn out to cpu: %w", err)
		}
		qStride := cfg.NumHeads * cfg.HeadDim
		kvDim := cfg.NumKVHeads * cfg.HeadDim
		kAll := kCPU.DataFloat32()
		vAll := vCPU.DataFloat32()
		outAll := outCPU.DataFloat32()
		queryPos := make([]int, seqLen)
		packedViews := make([][]kvcache.PackedView, seqLen)
		viewsByToken := make([][]kvcache.View, seqLen)
		allPacked := true
		for i := 0; i < seqLen; i++ {
			if caches[i] == nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("batched attention requires non-nil KV cache for token %d", i)
			}
			layerDev := caches[i].LayerDevice(layer.Idx).Normalize()
			if layerDev.Type == tensor.CUDA {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("CPU batched attention requires CPU KV cache for token %d (got %v)", i, layerDev)
			}
			queryPos[i] = caches[i].SeqLen()
			kRow := cpu.NewTensor(tensor.Shape{1, kvDim}, kAll[i*kvDim:(i+1)*kvDim])
			vRow := cpu.NewTensor(tensor.Shape{1, kvDim}, vAll[i*kvDim:(i+1)*kvDim])
			views, _, err := caches[i].Append(layer.Idx, kRow, vRow)
			if err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("cache append: %w", err)
			}
			viewsByToken[i] = views
			if pv, ok := caches[i].(kvcache.PackedViewsProvider); ok {
				p := pv.ViewsPacked(layer.Idx)
				if len(p) > 0 {
					packedViews[i] = p
				} else {
					allPacked = false
				}
			} else {
				allPacked = false
			}
		}
		if allPacked {
			if err := nn.CausalAttentionPackedBlocksBatch(qCPU, packedViews, outCPU, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("attention packed batch: %w", err)
			}
		} else {
			for i := 0; i < seqLen; i++ {
				qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qCPU.DataFloat32()[i*qStride:(i+1)*qStride])
				outRow := cpu.NewTensor(tensor.Shape{1, qStride}, outAll[i*qStride:(i+1)*qStride])
				if len(packedViews[i]) != 0 {
					if err := nn.CausalAttentionPackedBlocks(qRow, packedViews[i], outRow, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos[i]); err != nil {
						profile.End("Block/Attention")
						return nil, fmt.Errorf("attention packed: %w", err)
					}
					continue
				}
				if err := nn.CausalAttentionBlocks(qRow, viewsByToken[i], outRow, cfg.NumHeads, cfg.NumKVHeads, cfg.HeadDim, queryPos[i]); err != nil {
					profile.End("Block/Attention")
					return nil, fmt.Errorf("attention blocks: %w", err)
				}
			}
		}
		// Restore attention output to GPU so the output projection stays on GPU.
		if ctx != nil && ctx.IsGPU() {
			if _, err := attnOut.EnsureOn(ctx.Placement()); err != nil {
				profile.End("Block/Attention")
				return nil, fmt.Errorf("move attn out to gpu: %w", err)
			}
		}
		didPaged = true
	}
	profile.End("Block/Attention")
	attnProj, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/OProj")
	if err := compute.HybridLinear(ctx, attnOut, layer.Wo, attnProj); err != nil {
		profile.End("Block/OProj")
		return nil, fmt.Errorf("o_proj: %w", err)
	}
	profile.End("Block/OProj")
	if err := compute.HybridAdd(ctx, residual, attnProj); err != nil {
		return nil, fmt.Errorf("residual 1: %w", err)
	}
	return residual, nil
}
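
// runHybridMLPBlock computes the MLP half of the layer: MLP RMSNorm, gate and
// up projections, SwiGLU, down projection, and the residual add. It returns
// the block output (pre-norm input plus the MLP result).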
func runHybridMLPBlock(ctx *compute.Context, alloc hybridAllocFn, postAttn *compute.Activation, layer *HybridDecoderLayerWeights, cfg HybridDecoderConfig, eps float32) (*compute.Activation, error) {
	seqLen := postAttn.Shape()[0]
	residual2, err := cloneActivation(ctx, alloc, postAttn, false)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/MlpNorm")
	if err := compute.HybridRMSNorm(ctx, postAttn, layer.MlpNorm, eps); err != nil {
		profile.End("Block/MlpNorm")
		return nil, fmt.Errorf("mlp norm: %w", err)
	}
	profile.End("Block/MlpNorm")
	gate, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	up, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/GateProj")
	if err := compute.HybridLinear(ctx, postAttn, layer.WGate, gate); err != nil {
		profile.End("Block/GateProj")
		return nil, fmt.Errorf("gate proj: %w", err)
	}
	profile.End("Block/GateProj")
	profile.Start("Block/UpProj")
	if err := compute.HybridLinear(ctx, postAttn, layer.WUp, up); err != nil {
		profile.End("Block/UpProj")
		return nil, fmt.Errorf("up proj: %w", err)
	}
	profile.End("Block/UpProj")
	profile.Start("Block/SwiGLU")
	act, err := alloc(tensor.Shape{seqLen, cfg.Intermediate}, ctx.Placement(), true)
	if err != nil {
		profile.End("Block/SwiGLU")
		return nil, err
	}
	if err := compute.HybridSwiGLU(ctx, gate, up, act); err != nil {
		profile.End("Block/SwiGLU")
		return nil, err
	}
	profile.End("Block/SwiGLU")
	mlpOut, err := alloc(tensor.Shape{seqLen, cfg.HiddenSize}, ctx.Placement(), true)
	if err != nil {
		return nil, err
	}
	profile.Start("Block/DownProj")
	if err := compute.HybridLinear(ctx, act, layer.WDown, mlpOut); err != nil {
		profile.End("Block/DownProj")
		return nil, fmt.Errorf("down proj: %w", err)
	}
	profile.End("Block/DownProj")
	if err := compute.HybridAdd(ctx, residual2, mlpOut); err != nil {
		return nil, fmt.Errorf("residual 2: %w", err)
	}
	return residual2, nil
}
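
// gatherKV appends the current step's K/V to the cache and returns the full
// K/V history for the layer as contiguous activations along with the start
// position. On a CUDA-resident cache it reuses a contiguous view or
// concatenates block views on device; otherwise it materializes the history
// on the CPU and moves it to the context device when the layer runs on GPU.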
func gatherKV(ctx *compute.Context, cache kvcache.KVCacheInterface, layerIdx int, kOut, vOut *compute.Activation, kvDim int) (*compute.Activation, *compute.Activation, int, error) {
	seqLen := kOut.Shape()[0]
	startPos := 0
	if cache == nil {
		return kOut, vOut, startPos, nil
	}
	startPos = cache.SeqLen()
	layerDev := cache.LayerDevice(layerIdx).Normalize()
	if layerDev.Type == tensor.CUDA && ctx != nil && ctx.IsGPU() && cuda.Available() {
		profile.Instant("KVCache/GPU_path", profile.EventOp, "")
		if _, _, err := cache.Append(layerIdx, kOut.Tensor(), vOut.Tensor()); err != nil {
			return nil, nil, 0, fmt.Errorf("cache append: %w", err)
		}
		kvLen := startPos + seqLen
		if kView, vView, ok, err := cache.ContiguousKV(layerIdx, kvLen, kvDim); err != nil {
			return nil, nil, 0, fmt.Errorf("contiguous kv: %w", err)
		} else if ok {
			if kView == nil || vView == nil {
				return kOut, vOut, startPos, nil
			}
			return compute.NewActivationFrom(kView), compute.NewActivationFrom(vView), startPos, nil
		}
		kAct, vAct, err := concatKVOnDevice(ctx, cache.Views(layerIdx), startPos+seqLen, kvDim)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("concat kv on device: %w", err)
		}
		// Append the current step's K/V to the end of the full buffer.
		// PagedKVCache views only include committed tokens, so we must manually append the new ones.
		kDst := kAct.Tensor().(*cuda.Tensor)
		vDst := vAct.Tensor().(*cuda.Tensor)
		kSrc := kOut.Tensor().(*cuda.Tensor)
		vSrc := vOut.Tensor().(*cuda.Tensor)
		dstOffset := startPos * kvDim
		copyLen := seqLen * kvDim
		if err := kDst.CopyPartialFromDevice(dstOffset, kSrc, 0, copyLen); err != nil {
			return nil, nil, 0, fmt.Errorf("copy current K: %w", err)
		}
		if err := vDst.CopyPartialFromDevice(dstOffset, vSrc, 0, copyLen); err != nil {
			return nil, nil, 0, fmt.Errorf("copy current V: %w", err)
		}
		return kAct, vAct, startPos, nil
	}
	profile.Instant("KVCache/CPU_path", profile.EventOp, fmt.Sprintf("layerDev=%v ctxGPU=%v", layerDev, ctx != nil && ctx.IsGPU()))
	kCPU, err := kOut.AsCPU()
	if err != nil {
		return nil, nil, 0, err
	}
	vCPU, err := vOut.AsCPU()
	if err != nil {
		return nil, nil, 0, err
	}
	views, _, err := cache.Append(layerIdx, kCPU, vCPU)
	if err != nil {
		return nil, nil, 0, fmt.Errorf("cache append: %w", err)
	}
	kvLen := startPos + seqLen
	fullKData := make([]float32, kvLen*kvDim)
	fullVData := make([]float32, kvLen*kvDim)
	for _, view := range views {
		kData, err := getViewData(view.K)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("get K data: %w", err)
		}
		vData, err := getViewData(view.V)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("get V data: %w", err)
		}
		for i := 0; i < view.Length; i++ {
			globalPos := view.Start + i
			copy(fullKData[globalPos*kvDim:(globalPos+1)*kvDim], kData[i*kvDim:(i+1)*kvDim])
			copy(fullVData[globalPos*kvDim:(globalPos+1)*kvDim], vData[i*kvDim:(i+1)*kvDim])
		}
	}
	fullK := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{kvLen, kvDim}, fullKData))
	fullV := compute.NewActivationFrom(cpu.NewTensor(tensor.Shape{kvLen, kvDim}, fullVData))
	if ctx != nil && ctx.IsGPU() {
		profile.Start("KVCache/FullK_H2D")
		if err := compute.EnsureOnDevice(fullK, ctx.Placement()); err != nil {
			profile.End("KVCache/FullK_H2D")
			return nil, nil, 0, err
		}
		profile.End("KVCache/FullK_H2D")
		profile.Start("KVCache/FullV_H2D")
		if err := compute.EnsureOnDevice(fullV, ctx.Placement()); err != nil {
			profile.End("KVCache/FullV_H2D")
			return nil, nil, 0, err
		}
		profile.End("KVCache/FullV_H2D")
	}
	return fullK, fullV, startPos, nil
}
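
// getViewData returns a view tensor's contents as a float32 slice, copying
// device tensors to host when they implement CopyToHost.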
func getViewData(t tensor.Tensor) ([]float32, error) {
	switch tt := t.(type) {
	case *cpu.Tensor:
		return tt.DataFloat32(), nil
	default:
		if copier, ok := t.(interface{ CopyToHost([]float32) error }); ok {
			data := make([]float32, t.Shape().NumElements())
			if err := copier.CopyToHost(data); err != nil {
				return nil, err
			}
			return data, nil
		}
		return nil, fmt.Errorf("unsupported tensor type: %T", t)
	}
}
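
// concatKVOnDevice copies a layer's CUDA KV cache block views into newly
// allocated contiguous K and V buffers of shape {kvLen, kvDim} on the
// context's GPU.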
func concatKVOnDevice(ctx *compute.Context, views []kvcache.View, kvLen, kvDim int) (*compute.Activation, *compute.Activation, error) {
	gpu := ctx.Placement().GPU
	fullKAct, err := func() (*compute.Activation, error) {
		if ctx != nil && ctx.Scratch != nil {
			if a, err := ctx.Scratch.GetTensor(tensor.Shape{kvLen, kvDim}, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(tensor.Shape{kvLen, kvDim}, ctx.Placement())
	}()
	if err != nil {
		return nil, nil, err
	}
	fullVAct, err := func() (*compute.Activation, error) {
		if ctx != nil && ctx.Scratch != nil {
			if a, err := ctx.Scratch.GetTensor(tensor.Shape{kvLen, kvDim}, tensor.Float32); err == nil {
				return a, nil
			}
		}
		return compute.NewActivation(tensor.Shape{kvLen, kvDim}, ctx.Placement())
	}()
	if err != nil {
		return nil, nil, err
	}
	fullKGPU, err := fullKAct.AsCUDA(gpu)
	if err != nil {
		return nil, nil, err
	}
	fullVGPU, err := fullVAct.AsCUDA(gpu)
	if err != nil {
		return nil, nil, err
	}
	for _, view := range views {
		kSrc, ok := view.K.(*cuda.Tensor)
		if !ok {
			return nil, nil, fmt.Errorf("expected CUDA tensor for K view, got %T", view.K)
		}
		vSrc, ok := view.V.(*cuda.Tensor)
		if !ok {
			return nil, nil, fmt.Errorf("expected CUDA tensor for V view, got %T", view.V)
		}
		dstStart := view.Start * kvDim
		length := view.Length * kvDim
		if err := fullKGPU.CopyPartialFromDevice(dstStart, kSrc, 0, length); err != nil {
			return nil, nil, fmt.Errorf("copy K view: %w", err)
		}
		if err := fullVGPU.CopyPartialFromDevice(dstStart, vSrc, 0, length); err != nil {
			return nil, nil, fmt.Errorf("copy V view: %w", err)
		}
	}
	return fullKAct, fullVAct, nil
}