//go:build cuda // Package compute provides device-agnostic computation dispatching. package compute import ( "fmt" "unsafe" "makarna/pkg/backend/cpu" "makarna/pkg/backend/cpu/matmul" "makarna/pkg/backend/cuda" "makarna/pkg/backend/device" "makarna/pkg/profile" "makarna/pkg/tensor" ) // Linear performs a linear layer: output = input @ weight.T // Automatically dispatches to CPU or CUDA based on tensor placement. // Uses GPU weight cache for persistent weight storage. func Linear(ctx *Context, input, weight, output tensor.Tensor) error { useGPU := ctx != nil && ctx.IsGPU() && device.CUDAAvailable() if !useGPU { profile.Start("Linear/CPU") err := linearCPU(input, weight, output) profile.End("Linear/CPU") return err } switch weight.DType() { case tensor.Float32: profile.Start("Linear/F32") err := linearCUDAF32(ctx, input, weight, output) profile.End("Linear/F32") return err case tensor.Q8_K: profile.Start("Linear/Q8K") err := linearCUDAQ8K(ctx, input, weight, output) profile.End("Linear/Q8K") return err case tensor.Q5_K: profile.Start("Linear/Q5K") err := linearCUDAQ5K(ctx, input, weight, output) profile.End("Linear/Q5K") return err case tensor.Q4_K: profile.Start("Linear/Q4K") err := linearCUDAQ4K(ctx, input, weight, output) profile.End("Linear/Q4K") return err case tensor.Q2_K: profile.Start("Linear/Q2K") err := linearCUDAQ2K(ctx, input, weight, output) profile.End("Linear/Q2K") return err case tensor.Q3_K: profile.Start("Linear/Q3K") err := linearCUDAQ3K(ctx, input, weight, output) profile.End("Linear/Q3K") return err case tensor.Q6_K: profile.Start("Linear/Q6K") err := linearCUDAQ6K(ctx, input, weight, output) profile.End("Linear/Q6K") return err default: profile.Start("Linear/CPU") err := linearCPU(input, weight, output) profile.End("Linear/CPU") return err } } func linearCPU(input, weight, output tensor.Tensor) error { inCPU, ok := input.(*cpu.Tensor) if !ok { var err error inCPU, err = ToCPU(input) if err != nil { return fmt.Errorf("linear: failed to get CPU input: %w", err) } } wCPU, ok := weight.(*cpu.Tensor) if !ok { return fmt.Errorf("linear: weight must be CPU tensor for CPU path") } outCPU, ok := output.(*cpu.Tensor) if !ok { return fmt.Errorf("linear: output must be CPU tensor for CPU path") } return matmul.Linear(inCPU, wCPU, outCPU) } // linearCUDAF32 - F32 weights, uses cache func linearCUDAF32(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU // Get cached GPU input or upload profile.Start("Linear/F32/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/F32/input_upload") if err != nil { return err } // Get cached weight cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_f16_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/F32/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.UploadF16(weightKey, cpuW) profile.End("Linear/F32/weight_upload") if err != nil { return fmt.Errorf("linear F32: cache weight: %w", err) } } // Allocate output gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear F32: alloc output: %w", err) } defer gpuOutput.Free() // Execute matmul using raw pointers profile.Start("Linear/F32/matmul_kernel") aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) if gpuInput.DType() != tensor.Float16 { profile.Start("Linear/F32/cast_fp16") f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu) if err != nil { profile.End("Linear/F32/cast_fp16") profile.End("Linear/F32/matmul_kernel") return fmt.Errorf("linear F32: alloc f16 input: %w", err) } defer f16In.Free() if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil { profile.End("Linear/F32/cast_fp16") profile.End("Linear/F32/matmul_kernel") return fmt.Errorf("linear F32: cast input f32->f16: %w", err) } aPtr = f16In.Data().(unsafe.Pointer) profile.End("Linear/F32/cast_fp16") } err = cuda.MatMulF16(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/F32/matmul_kernel") if err != nil { return fmt.Errorf("linear F32: matmul f16: %w", err) } // Copy back to CPU output if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear F32: copy D2H: %w", err) } } return nil } func linearCUDAQ5K(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU profile.Start("Linear/Q5K/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/Q5K/input_upload") if err != nil { return err } cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/Q5K/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.Upload(weightKey, cpuW) profile.End("Linear/Q5K/weight_upload") if err != nil { return fmt.Errorf("linear Q5K: cache weight: %w", err) } } gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear Q5K: alloc output: %w", err) } defer gpuOutput.Free() aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) profile.Start("Linear/Q5K/matmul_kernel") err = cuda.MatMulQ5K(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/Q5K/matmul_kernel") if err != nil { return fmt.Errorf("linear Q5K: matmul: %w", err) } if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear Q5K: copy D2H: %w", err) } } return nil } // linearCUDAQ8K - Q8_K weights with caching func linearCUDAQ8K(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU // Get GPU input profile.Start("Linear/Q8K/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/Q8K/input_upload") if err != nil { return err } // Get cached weight or upload cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/Q8K/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.Upload(weightKey, cpuW) profile.End("Linear/Q8K/weight_upload") if err != nil { return fmt.Errorf("linear Q8K: cache weight: %w", err) } } // Allocate output gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear Q8K: alloc output: %w", err) } defer gpuOutput.Free() // Execute fused matmul aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) profile.Start("Linear/Q8K/matmul_kernel") err = cuda.MatMulQ8K(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/Q8K/matmul_kernel") if err != nil { return fmt.Errorf("linear Q8K: matmul: %w", err) } // Copy back if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear Q8K: copy D2H: %w", err) } } return nil } // linearCUDAQ4K - Q4_K weights with caching func linearCUDAQ4K(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU // Get GPU input profile.Start("Linear/Q4K/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/Q4K/input_upload") if err != nil { return err } // Get cached weight cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/Q4K/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.Upload(weightKey, cpuW) profile.End("Linear/Q4K/weight_upload") if err != nil { return fmt.Errorf("linear Q4K: cache weight: %w", err) } } // Allocate output gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear Q4K: alloc output: %w", err) } defer gpuOutput.Free() // Execute fused matmul aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) profile.Start("Linear/Q4K/matmul_kernel") err = cuda.MatMulQ4K(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/Q4K/matmul_kernel") if err != nil { return fmt.Errorf("linear Q4K: matmul: %w", err) } // Copy back if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear Q4K: copy D2H: %w", err) } } return nil } // linearCUDAQ2K - Q2_K weights with caching func linearCUDAQ2K(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU // Get GPU input profile.Start("Linear/Q2K/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/Q2K/input_upload") if err != nil { return err } // Get cached weight cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/Q2K/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.Upload(weightKey, cpuW) profile.End("Linear/Q2K/weight_upload") if err != nil { return fmt.Errorf("linear Q2K: cache weight: %w", err) } } // Allocate output gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear Q2K: alloc output: %w", err) } defer gpuOutput.Free() // Execute fused matmul profile.Start("Linear/Q2K/matmul_kernel") aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) err = cuda.MatMulQ2K(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/Q2K/matmul_kernel") if err != nil { return fmt.Errorf("linear Q2K: matmul: %w", err) } // Copy back if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear Q2K: copy D2H: %w", err) } } return nil } // linearCUDAQ3K - Q3_K weights with caching func linearCUDAQ3K(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU // Get GPU input profile.Start("Linear/Q3K/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/Q3K/input_upload") if err != nil { return err } // Get cached weight cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/Q3K/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.Upload(weightKey, cpuW) profile.End("Linear/Q3K/weight_upload") if err != nil { return fmt.Errorf("linear Q3K: cache weight: %w", err) } } // Allocate output gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear Q3K: alloc output: %w", err) } defer gpuOutput.Free() // Execute fused matmul profile.Start("Linear/Q3K/matmul_kernel") aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) err = cuda.MatMulQ3K(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/Q3K/matmul_kernel") if err != nil { return fmt.Errorf("linear Q3K: matmul: %w", err) } // Copy back if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear Q3K: copy D2H: %w", err) } } return nil } // linearCUDAQ6K - Q6_K weights with caching func linearCUDAQ6K(ctx *Context, input, weight, output tensor.Tensor) error { inShape := input.Shape() wShape := weight.Shape() M, K, N := inShape[0], inShape[1], wShape[0] gpu := ctx.Placement().GPU // Get GPU input profile.Start("Linear/Q6K/input_upload") gpuInput, err := getOrUploadInput(input, gpu) profile.End("Linear/Q6K/input_upload") if err != nil { return err } // Get cached weight cache := GetWeightCache(gpu) weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor)) gpuWeight, ok := cache.Get(weightKey) if !ok { profile.Start("Linear/Q6K/weight_upload") cpuW := weight.(*cpu.Tensor) gpuWeight, err = cache.Upload(weightKey, cpuW) profile.End("Linear/Q6K/weight_upload") if err != nil { return fmt.Errorf("linear Q6K: cache weight: %w", err) } } // Allocate output gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu) if err != nil { return fmt.Errorf("linear Q6K: alloc output: %w", err) } defer gpuOutput.Free() // Execute fused matmul profile.Start("Linear/Q6K/matmul_kernel") aPtr := gpuInput.Data().(unsafe.Pointer) cPtr := gpuOutput.Data().(unsafe.Pointer) err = cuda.MatMulQ6K(aPtr, gpuWeight, cPtr, M, K, N, gpu) profile.End("Linear/Q6K/matmul_kernel") if err != nil { return fmt.Errorf("linear Q6K: matmul: %w", err) } // Copy back if cpuOut, ok := output.(*cpu.Tensor); ok { if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil { return fmt.Errorf("linear Q6K: copy D2H: %w", err) } } return nil } // getOrUploadInput uploads CPU input to GPU func getOrUploadInput(input tensor.Tensor, gpu int) (*cuda.Tensor, error) { if cudaIn, ok := input.(*cuda.Tensor); ok { return cudaIn, nil } cpuIn, ok := input.(*cpu.Tensor) if !ok { return nil, fmt.Errorf("input must be CPU or CUDA tensor") } shape := input.Shape() gpuInput, err := cuda.NewTensor(shape, tensor.Float32, gpu) if err != nil { return nil, fmt.Errorf("alloc GPU input: %w", err) } if err := gpuInput.CopyFrom(cpuIn.DataFloat32()); err != nil { return nil, fmt.Errorf("copy input H2D: %w", err) } return gpuInput, nil }