- //go:build cuda
- 
- // Package compute provides hybrid CPU/GPU neural network operations.
- // Operations automatically dispatch to the correct backend based on activation placement.
- package compute
- import (
- "fmt"
- "math"
- "unsafe"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cpu/matmul"
- "makarna/pkg/backend/cpu/nn"
- "makarna/pkg/backend/cuda"
- "makarna/pkg/profile"
- "makarna/pkg/tensor"
- )
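- // Illustrative call sequence for one decoder layer (hypothetical caller; the tensor and
- // activation names below are placeholders, not APIs defined in this package):
- //
- //  if err := HybridRMSNorm(ctx, hidden, attnNormW, 1e-6); err != nil { return err }
- //  if err := HybridLinear(ctx, hidden, wq, q); err != nil { return err }
- //  if err := HybridRoPE(ctx, q, positions, headDim, ropeTheta); err != nil { return err }
- //  if err := HybridAttention(ctx, q, k, v, attnOut, nHeads, nKV, headDim, scale, startPos); err != nil { return err }
- //
- // Each Hybrid* entry point inspects the activation placement and the Context and dispatches
- // to the CPU or CUDA implementation accordingly.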
- // HybridLinear performs matrix multiplication on either CPU or GPU.
- // On the GPU path, weights are uploaded once and then served from the per-device weight cache.
- func HybridLinear(ctx *Context, input *Activation, weight tensor.Tensor, output *Activation) error {
- if input.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridLinear/GPU")
- err := hybridLinearGPU(ctx, input, weight, output)
- profile.End("HybridLinear/GPU")
- return err
- }
- profile.Start("HybridLinear/CPU")
- err := hybridLinearCPU(input, weight, output)
- profile.End("HybridLinear/CPU")
- return err
- }
- func hybridLinearCPU(input *Activation, weight tensor.Tensor, output *Activation) error {
- inCPU, err := input.AsCPU()
- if err != nil {
- return err
- }
- outCPU, err := output.AsCPU()
- if err != nil {
- return err
- }
- wCPU, ok := weight.(*cpu.Tensor)
- if !ok {
- return fmt.Errorf("hybrid linear: expected *cpu.Tensor weight, got %T", weight)
- }
- return matmul.Linear(inCPU, wCPU, outCPU)
- }
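- // hybridLinearGPU computes output = input x weight^T on the GPU, uploading each weight to the
- // per-device cache once and reusing it afterwards. Shapes as read below: input is [M, K], the
- // weight tensor is [N, K] (wShape[0] is N), and the FP32 output is [M, N]. K-quantized weights
- // use dedicated kernels, preferably with FP16 activations; all other dtypes go through the
- // dense FP16 GEMM path.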
- func hybridLinearGPU(ctx *Context, input *Activation, weight tensor.Tensor, output *Activation) error {
- gpu := ctx.Placement().GPU
- inShape := input.Shape()
- wShape := weight.Shape()
- M, K, N := inShape[0], inShape[1], wShape[0]
- // Get GPU input
- gpuIn, err := input.AsCUDA(gpu)
- if err != nil {
- return err
- }
- // Get cached weight
- cache := GetWeightCache(gpu)
- var weightKey string
- if wCPU, ok := weight.(*cpu.Tensor); ok {
- weightKey = fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, wCPU)
- } else {
- weightKey = fmt.Sprintf("layer%d_w_%T_%p", ctx.LayerIdx, weight, weight)
- }
- upload := cache.Upload
- if weight.DType() == tensor.Float32 {
- // Use a separate key so float32 weights needed by other ops can coexist with FP16 GEMM weights.
- if wCPU, ok := weight.(*cpu.Tensor); ok {
- weightKey = fmt.Sprintf("layer%d_w_f16_%p", ctx.LayerIdx, wCPU)
- } else {
- weightKey = fmt.Sprintf("layer%d_w_f16_%T_%p", ctx.LayerIdx, weight, weight)
- }
- upload = cache.UploadF16
- }
- gpuWeight, ok := cache.Get(weightKey)
- if !ok {
- cpuW, isCPU := weight.(*cpu.Tensor)
- if !isCPU {
- return fmt.Errorf("hybrid linear: cannot upload weight of type %T to the GPU cache", weight)
- }
- gpuWeight, err = upload(weightKey, cpuW)
- if err != nil {
- return fmt.Errorf("hybrid linear: cache weight: %w", err)
- }
- }
- // Reuse preallocated output buffer when possible (e.g., scratch views).
- var gpuOut *cuda.Tensor
- if output != nil && output.IsGPU() {
- if outT, err := output.AsCUDA(gpu); err == nil {
- if outT.DType() == tensor.Float32 {
- shape := outT.Shape()
- if len(shape) == 2 && shape[0] == M && shape[1] == N {
- gpuOut = outT
- }
- }
- }
- }
- if gpuOut == nil {
- t, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
- if err != nil {
- return err
- }
- output.ReplaceWith(t)
- gpuOut = t
- }
- // Execute based on weight dtype and input dtype
- aPtr := gpuIn.Data().(unsafe.Pointer)
- cPtr := gpuOut.Data().(unsafe.Pointer)
- inputIsF16 := gpuIn.DType() == tensor.Float16
- // Prefer FP16 input for quant matmuls (memory bandwidth win).
- // If activations are still FP32, cast to FP16 on GPU and use the FP16 kernels.
- if !inputIsF16 {
- switch weight.DType() {
- case tensor.Q8_K, tensor.Q5_K, tensor.Q4_K, tensor.Q2_K, tensor.Q3_K, tensor.Q6_K:
- f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
- if err != nil {
- return err
- }
- defer f16In.Free()
- if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil {
- return err
- }
- aPtr = f16In.Data().(unsafe.Pointer)
- inputIsF16 = true
- }
- }
- switch weight.DType() {
- case tensor.Q8_K:
- if inputIsF16 {
- if err := cuda.MatMulF16Q8K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- } else {
- if err := cuda.MatMulQ8K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- case tensor.Q5_K:
- if inputIsF16 {
- if err := cuda.MatMulF16Q5K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- } else {
- if err := cuda.MatMulQ5K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- case tensor.Q4_K:
- if inputIsF16 {
- if err := cuda.MatMulF16Q4K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- } else {
- if err := cuda.MatMulQ4K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- case tensor.Q2_K:
- if inputIsF16 {
- if err := cuda.MatMulF16Q2K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- } else {
- if err := cuda.MatMulQ2K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- case tensor.Q3_K:
- if inputIsF16 {
- if err := cuda.MatMulF16Q3K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- } else {
- if err := cuda.MatMulQ3K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- case tensor.Q6_K:
- if inputIsF16 {
- if err := cuda.MatMulF16Q6K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- } else {
- if err := cuda.MatMulQ6K(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- default:
- // Dense GEMM path (weights cached as FP16 on GPU).
- if !inputIsF16 {
- f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
- if err != nil {
- return err
- }
- defer f16In.Free()
- if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil {
- return err
- }
- aPtr = f16In.Data().(unsafe.Pointer)
- inputIsF16 = true
- }
- if err := cuda.MatMulF16(aPtr, gpuWeight, cPtr, M, K, N, gpu); err != nil {
- return err
- }
- }
- return nil
- }
- // HybridRMSNorm applies RMS normalization in-place.
- func HybridRMSNorm(ctx *Context, x *Activation, w tensor.Tensor, eps float32) error {
- if x.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridRMSNorm/GPU")
- err := hybridRMSNormGPU(ctx, x, w, eps)
- profile.End("HybridRMSNorm/GPU")
- return err
- }
- profile.Start("HybridRMSNorm/CPU")
- err := hybridRMSNormCPU(x, w, eps)
- profile.End("HybridRMSNorm/CPU")
- return err
- }
- func hybridRMSNormCPU(x *Activation, w tensor.Tensor, eps float32) error {
- xCPU, err := x.AsCPU()
- if err != nil {
- return err
- }
- wCPU := w.(*cpu.Tensor)
- xData := xCPU.DataFloat32()
- wData := wCPU.DataFloat32()
- dim := wCPU.Shape().NumElements()
- numRows := xCPU.Shape().NumElements() / dim
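- // Per row: y[j] = x[j] / sqrt(mean(x^2) + eps) * w[j]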
- for i := 0; i < numRows; i++ {
- row := xData[i*dim : (i+1)*dim]
- ss := cpu.DotFloat32(row, row) / float32(dim)
- invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
- for j := 0; j < dim; j++ {
- row[j] = row[j] * invRMS * wData[j]
- }
- }
- return nil
- }
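- // hybridRMSNormGPU normalizes x on the GPU. When the weight is smaller than the activation
- // dimension but divides it evenly (per-head qNorm/kNorm), the rows are reinterpreted as
- // [seqLen*numHeads, headDim] and the same kernel is reused; otherwise it falls back to the CPU
- // implementation and moves the result back to the GPU.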
- func hybridRMSNormGPU(ctx *Context, x *Activation, w tensor.Tensor, eps float32) error {
- gpu := ctx.Placement().GPU
- shape := x.Shape()
- seqLen, dim := shape[0], shape[1]
- wShape := w.Shape()
- wDim := wShape.NumElements()
- // Per-head normalization (qNorm/kNorm): when the activation dimension is a whole multiple of
- // the weight dimension, the rows can be viewed as a flattened sequence of heads and the same
- // kernel runs on GPU with reshaped dimensions.
- if wDim != dim {
- if dim%wDim == 0 {
- // e.g. dim=3584, wDim=128 (28 heads).
- // We can treat this as [seqLen * numHeads, headDim]
- numHeads := dim / wDim
- effectiveSeqLen := seqLen * numHeads
-
- // We use the same kernel, just with modified dimensions
- gpuX, err := x.AsCUDA(gpu)
- if err != nil {
- return err
- }
- // Get cached weight
- cache := GetWeightCache(gpu)
- var wKey string
- if wCPU, ok := w.(*cpu.Tensor); ok {
- wKey = fmt.Sprintf("norm_%p", wCPU)
- } else {
- wKey = fmt.Sprintf("norm_%T_%p", w, w)
- }
- gpuW, ok := cache.Get(wKey)
- if !ok {
- gpuW, err = cache.Upload(wKey, w.(*cpu.Tensor))
- if err != nil {
- return fmt.Errorf("rmsnorm: cache upload failed: %w", err)
- }
- }
- if gpuW == nil {
- return fmt.Errorf("rmsnorm: got nil weight pointer from cache")
- }
-
- return cuda.RMSNorm(gpuX.Data().(unsafe.Pointer), gpuW, effectiveSeqLen, wDim, eps, gpu)
- }
- // Fallback to CPU if we can't reshape cleanly
- // Per-head normalization - fall back to CPU but restore to GPU after
- wasGPU := x.IsGPU()
- if err := hybridRMSNormCPU(x, w, eps); err != nil {
- return err
- }
- // Restore to GPU if it was on GPU before
- if wasGPU {
- if _, err := x.EnsureOn(ctx.Placement()); err != nil {
- return fmt.Errorf("restore to GPU after per-head norm: %w", err)
- }
- }
- return nil
- }
- gpuX, err := x.AsCUDA(gpu)
- if err != nil {
- return err
- }
- // Get cached weight
- cache := GetWeightCache(gpu)
- var wKey string
- if wCPU, ok := w.(*cpu.Tensor); ok {
- wKey = fmt.Sprintf("norm_%p", wCPU)
- } else {
- wKey = fmt.Sprintf("norm_%T_%p", w, w)
- }
- gpuW, ok := cache.Get(wKey)
- if !ok {
- gpuW, err = cache.Upload(wKey, w.(*cpu.Tensor))
- if err != nil {
- return fmt.Errorf("rmsnorm: cache upload failed: %w", err)
- }
- }
- if gpuW == nil {
- return fmt.Errorf("rmsnorm: got nil weight pointer from cache")
- }
- // Standard case: weight dimension matches activation dimension
- return cuda.RMSNorm(gpuX.Data().(unsafe.Pointer), gpuW, seqLen, dim, eps, gpu)
- }
- // HybridRoPE applies rotary positional embeddings in-place.
- func HybridRoPE(ctx *Context, x *Activation, positions []int, headDim int, theta float32) error {
- if x.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridRoPE/GPU")
- err := hybridRoPEGPU(ctx, x, positions, headDim, theta)
- profile.End("HybridRoPE/GPU")
- return err
- }
- profile.Start("HybridRoPE/CPU")
- err := hybridRoPECPU(x, positions, headDim, theta)
- profile.End("HybridRoPE/CPU")
- return err
- }
- func hybridRoPECPU(x *Activation, positions []int, headDim int, theta float32) error {
- xCPU, err := x.AsCPU()
- if err != nil {
- return err
- }
- data := xCPU.DataFloat32()
- shape := x.Shape()
- seqLen := shape[0]
- totalDim := shape[1]
- halfDim := headDim / 2
- invFreqs := make([]float64, halfDim)
- for j := 0; j < halfDim; j++ {
- invFreqs[j] = 1.0 / math.Pow(float64(theta), float64(2*j)/float64(headDim))
- }
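- // For every head, rotate the pair (x[j], x[j+halfDim]) by angle pos*invFreqs[j]:
- // x0' = x0*cos - x1*sin, x1' = x1*cos + x0*sin (half-rotation pairing across the head).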
- for seq := 0; seq < seqLen; seq++ {
- pos := positions[seq]
- rowStart := seq * totalDim
- for headStart := 0; headStart < totalDim; headStart += headDim {
- for j := 0; j < halfDim; j++ {
- freq := float64(pos) * invFreqs[j]
- sin, cos := math.Sincos(freq)
- idx0 := rowStart + headStart + j
- idx1 := rowStart + headStart + j + halfDim
- v0 := data[idx0]
- v1 := data[idx1]
- data[idx0] = v0*float32(cos) - v1*float32(sin)
- data[idx1] = v1*float32(cos) + v0*float32(sin)
- }
- }
- }
- return nil
- }
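- // hybridRoPEGPU applies RoPE on the GPU. Single-token updates (decode) take a scalar fast path
- // with no positions buffer; multi-token calls upload the positions as int32, reusing the
- // per-step scratch allocator when available to avoid a device malloc/free per call.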
- func hybridRoPEGPU(ctx *Context, x *Activation, positions []int, headDim int, theta float32) error {
- gpu := ctx.Placement().GPU
- shape := x.Shape()
- seqLen := shape[0]
- totalDim := shape[1]
- numHeads := totalDim / headDim
- gpuX, err := x.AsCUDA(gpu)
- if err != nil {
- return err
- }
- // Optimization: for a single-token update (decode phase), pass the position to the kernel as
- // a scalar, avoiding any allocation and host-to-device copy for the positions buffer.
- if len(positions) == 1 {
- pos := positions[0]
- return cuda.RoPESingle(gpuX.Data().(unsafe.Pointer), pos, numHeads, headDim, theta, gpu)
- }
- // Upload positions as int32 (CUDA kernel expects int*)
- posData := make([]int32, len(positions))
- for i, p := range positions {
- posData[i] = int32(p)
- }
-
- var gpuPosPtr unsafe.Pointer
- // Try using scratch space if available to avoid malloc/free overhead
- if ctx != nil && ctx.Scratch != nil {
- gpuPosPtr, err = ctx.Scratch.GetInt32Slice(len(positions))
- }
-
- // Fall back to a fresh device allocation if scratch is unavailable or the scratch request failed.
- shouldFree := false
- if gpuPosPtr == nil || err != nil {
- gpuPosPtr, err = cuda.AllocAndCopyInt32(posData, gpu)
- if err != nil {
- return fmt.Errorf("RoPE: upload positions: %w", err)
- }
- shouldFree = true
- } else {
- // Scratch only hands out the buffer, so copy the positions into it ourselves
- // (AllocAndCopyInt32 above performs the alloc and the copy in one step).
- err = cuda.MemcpyH2D(gpuPosPtr, unsafe.Pointer(&posData[0]), uintptr(len(posData)*4), gpu)
- if err != nil {
- return fmt.Errorf("RoPE: memcpy positions: %w", err)
- }
- }
- if shouldFree {
- defer cuda.Free(gpuPosPtr)
- }
- if err := cuda.RoPE(gpuX.Data().(unsafe.Pointer), gpuPosPtr, seqLen, numHeads, headDim, theta, gpu); err != nil {
- return err
- }
- // No explicit synchronize here, for performance:
- // 1. If using scratch: the memory persists until the end-of-step scratch reset. Safe.
- // 2. If using alloc+free: the deferred free is ordered after the kernel launch, so the kernel
- // finishes reading the positions before the memory is released. Safe.
- return nil
- }
- // HybridSoftmax applies softmax along the last dimension in-place.
- func HybridSoftmax(ctx *Context, x *Activation) error {
- if x.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridSoftmax/GPU")
- err := hybridSoftmaxGPU(ctx, x)
- profile.End("HybridSoftmax/GPU")
- return err
- }
- profile.Start("HybridSoftmax/CPU")
- err := hybridSoftmaxCPU(x)
- profile.End("HybridSoftmax/CPU")
- return err
- }
- func hybridSoftmaxCPU(x *Activation) error {
- xCPU, err := x.AsCPU()
- if err != nil {
- return err
- }
- data := xCPU.DataFloat32()
- shape := x.Shape()
- rows, cols := shape[0], shape[1]
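- // Numerically stable softmax per row: subtract the row max before exponentiating so exp() cannot overflow.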
- for i := 0; i < rows; i++ {
- row := data[i*cols : (i+1)*cols]
- maxVal := row[0]
- for _, v := range row[1:] {
- if v > maxVal {
- maxVal = v
- }
- }
- sum := float32(0)
- for j := range row {
- row[j] = float32(math.Exp(float64(row[j] - maxVal)))
- sum += row[j]
- }
- for j := range row {
- row[j] /= sum
- }
- }
- return nil
- }
- func hybridSoftmaxGPU(ctx *Context, x *Activation) error {
- gpu := ctx.Placement().GPU
- shape := x.Shape()
- rows, cols := shape[0], shape[1]
- gpuX, err := x.AsCUDA(gpu)
- if err != nil {
- return err
- }
- return cuda.Softmax(gpuX.Data().(unsafe.Pointer), rows, cols, gpu)
- }
- // HybridSiLU applies SiLU activation in-place: x = x * sigmoid(x)
- func HybridSiLU(ctx *Context, x *Activation) error {
- if x.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridSiLU/GPU")
- err := hybridSiLUGPU(ctx, x)
- profile.End("HybridSiLU/GPU")
- return err
- }
- profile.Start("HybridSiLU/CPU")
- err := hybridSiLUCPU(x)
- profile.End("HybridSiLU/CPU")
- return err
- }
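- // HybridSwiGLU computes out = SiLU(gate) * up element-wise, built from the primitives above:
- // copy gate into out, apply SiLU in-place, then multiply by up.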
- func HybridSwiGLU(ctx *Context, gate, up, out *Activation) error {
- if err := HybridCopy(ctx, out, gate); err != nil {
- return err
- }
- if err := HybridSiLU(ctx, out); err != nil {
- return err
- }
- return HybridMul(ctx, out, up)
- }
- func hybridSiLUCPU(x *Activation) error {
- xCPU, err := x.AsCPU()
- if err != nil {
- return err
- }
- return nn.SiLU(xCPU)
- }
- func hybridSiLUGPU(ctx *Context, x *Activation) error {
- gpu := ctx.Placement().GPU
- gpuX, err := x.AsCUDA(gpu)
- if err != nil {
- return err
- }
- return cuda.SiLU(gpuX.Data().(unsafe.Pointer), x.Shape().NumElements(), gpu)
- }
- // HybridMul performs element-wise multiplication: a = a * b
- func HybridMul(ctx *Context, a, b *Activation) error {
- if a.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridMul/GPU")
- err := hybridMulGPU(ctx, a, b)
- profile.End("HybridMul/GPU")
- return err
- }
- profile.Start("HybridMul/CPU")
- err := hybridMulCPU(a, b)
- profile.End("HybridMul/CPU")
- return err
- }
- func hybridMulCPU(a, b *Activation) error {
- aCPU, err := a.AsCPU()
- if err != nil {
- return err
- }
- bCPU, err := b.AsCPU()
- if err != nil {
- return err
- }
- aData := aCPU.DataFloat32()
- bData := bCPU.DataFloat32()
- for i := range aData {
- aData[i] *= bData[i]
- }
- return nil
- }
- func hybridMulGPU(ctx *Context, a, b *Activation) error {
- gpu := ctx.Placement().GPU
- gpuA, err := a.AsCUDA(gpu)
- if err != nil {
- return err
- }
- gpuB, err := b.AsCUDA(gpu)
- if err != nil {
- return err
- }
- return cuda.MulInplace(gpuA.Data().(unsafe.Pointer), gpuB.Data().(unsafe.Pointer), a.Shape().NumElements(), gpu)
- }
- // HybridAdd performs element-wise addition: a = a + b
- func HybridAdd(ctx *Context, a, b *Activation) error {
- if a.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridAdd/GPU")
- err := hybridAddGPU(ctx, a, b)
- profile.End("HybridAdd/GPU")
- return err
- }
- profile.Start("HybridAdd/CPU")
- err := hybridAddCPU(a, b)
- profile.End("HybridAdd/CPU")
- return err
- }
- func hybridAddCPU(a, b *Activation) error {
- aCPU, err := a.AsCPU()
- if err != nil {
- return err
- }
- bCPU, err := b.AsCPU()
- if err != nil {
- return err
- }
- aData := aCPU.DataFloat32()
- bData := bCPU.DataFloat32()
- for i := range aData {
- aData[i] += bData[i]
- }
- return nil
- }
- func hybridAddGPU(ctx *Context, a, b *Activation) error {
- gpu := ctx.Placement().GPU
- gpuA, err := a.AsCUDA(gpu)
- if err != nil {
- return err
- }
- gpuB, err := b.AsCUDA(gpu)
- if err != nil {
- return err
- }
- return cuda.AddInplace(gpuA.Data().(unsafe.Pointer), gpuB.Data().(unsafe.Pointer), a.Shape().NumElements(), gpu)
- }
- // HybridAttention computes full causal attention.
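- // Q is [seqLen, numHeads*headDim] and K/V are [kvLen, numKVHeads*headDim]; when numKVHeads <
- // numHeads, each KV head is shared by numHeads/numKVHeads query heads (grouped-query attention).
- // startPos is the offset of the query block inside the cached sequence, used for the causal mask.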
- func HybridAttention(ctx *Context, Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
- if Q.IsGPU() && ctx != nil && ctx.IsGPU() {
- profile.Start("HybridAttention/GPU")
- err := hybridAttentionGPU(ctx, Q, K, V, out, numHeads, numKVHeads, headDim, scale, startPos)
- profile.End("HybridAttention/GPU")
- return err
- }
- profile.Start("HybridAttention/CPU")
- err := hybridAttentionCPU(Q, K, V, out, numHeads, numKVHeads, headDim, scale, startPos)
- profile.End("HybridAttention/CPU")
- return err
- }
- func hybridAttentionCPU(Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
- qCPU, err := Q.AsCPU()
- if err != nil {
- return err
- }
- kCPU, err := K.AsCPU()
- if err != nil {
- return err
- }
- vCPU, err := V.AsCPU()
- if err != nil {
- return err
- }
- outCPU, err := out.AsCPU()
- if err != nil {
- return err
- }
- // Prefer the optimized CPU kernels from backend/cpu/nn: they avoid per-token allocations and
- // use SIMD softmax/Axpy.
- // The nn implementation includes the attention scaling internally, so the explicit scale
- // argument is not forwarded; scaling Q into a temporary just to honor a caller-provided scale
- // would cost an extra copy.
- _ = scale
- // CausalAttentionCached covers both cached decode/prefill (startPos > 0) and plain causal
- // attention (seqLen == kvLen, startPos == 0).
- return nn.CausalAttentionCached(qCPU, kCPU, vCPU, outCPU, numHeads, numKVHeads, headDim, startPos)
- }
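- // hybridAttentionGPU runs the fused CUDA attention kernel. It allocates a fresh FP32 output
- // tensor on the device and swaps it into out via ReplaceWith rather than writing in place.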
- func hybridAttentionGPU(ctx *Context, Q, K, V, out *Activation, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
- gpu := ctx.Placement().GPU
- gpuQ, err := Q.AsCUDA(gpu)
- if err != nil {
- return err
- }
- gpuK, err := K.AsCUDA(gpu)
- if err != nil {
- return err
- }
- gpuV, err := V.AsCUDA(gpu)
- if err != nil {
- return err
- }
- // Allocate output on GPU
- gpuOut, err := cuda.NewTensor(out.Shape(), tensor.Float32, gpu)
- if err != nil {
- return err
- }
- seqLen := Q.Shape()[0]
- kvLen := K.Shape()[0]
- err = cuda.Attention(
- gpuQ.Data().(unsafe.Pointer),
- gpuK.Data().(unsafe.Pointer),
- gpuV.Data().(unsafe.Pointer),
- gpuOut.Data().(unsafe.Pointer),
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- scale, startPos, gpu,
- )
- if err != nil {
- return err
- }
- out.ReplaceWith(gpuOut)
- return nil
- }
- // HybridCopy copies src to dst.
- func HybridCopy(ctx *Context, dst, src *Activation) error {
- if dst.IsGPU() && src.IsGPU() && ctx != nil && ctx.IsGPU() {
- return hybridCopyGPU(ctx, dst, src)
- }
- return hybridCopyCPU(dst, src)
- }
- func hybridCopyCPU(dst, src *Activation) error {
- dstCPU, err := dst.AsCPU()
- if err != nil {
- return err
- }
- srcCPU, err := src.AsCPU()
- if err != nil {
- return err
- }
- copy(dstCPU.DataFloat32(), srcCPU.DataFloat32())
- return nil
- }
- func hybridCopyGPU(ctx *Context, dst, src *Activation) error {
- gpu := ctx.Placement().GPU
- gpuDst, err := dst.AsCUDA(gpu)
- if err != nil {
- return err
- }
- gpuSrc, err := src.AsCUDA(gpu)
- if err != nil {
- return err
- }
- return cuda.Copy(gpuDst.Data().(unsafe.Pointer), gpuSrc.Data().(unsafe.Pointer), dst.Shape().NumElements(), gpu)
- }
- // EnsureOnDevice moves an activation to the target device if needed.
- // This is the key function for hybrid execution: data is transferred only when an operation
- // crosses a device boundary.
- func EnsureOnDevice(a *Activation, target tensor.DevicePlacement) error {
- // EnsureOn reports whether a transfer happened; the flag is currently unused but could feed
- // debug logging of host<->device traffic.
- _, err := a.EnsureOn(target)
- return err
- }