//go:build cuda

// Package compute provides device-agnostic computation dispatching.
package compute

import (
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)
// Linear performs a linear layer: output = input @ weight.T.
// It automatically dispatches to the CPU or CUDA path based on tensor
// placement, and keeps weights resident on the GPU via the weight cache.
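//
// A minimal usage sketch; the tensor constructor below is hypothetical, but
// the shape contract holds: input is [M, K], weight is [N, K] (transposed on
// the fly), output is [M, N], and ctx decides CPU vs. GPU placement.
//
//	in := newCPUTensorF32(tensor.Shape{1, 4096})    // hypothetical helper
//	w := newCPUTensorF32(tensor.Shape{11008, 4096}) // hypothetical helper
//	out := newCPUTensorF32(tensor.Shape{1, 11008})  // hypothetical helper
//	if err := Linear(ctx, in, w, out); err != nil {
//		// handle error
//	}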
func Linear(ctx *Context, input, weight, output tensor.Tensor) error {
	useGPU := ctx != nil && ctx.IsGPU() && device.CUDAAvailable()
	if !useGPU {
		profile.Start("Linear/CPU")
		err := linearCPU(input, weight, output)
		profile.End("Linear/CPU")
		return err
	}
	switch weight.DType() {
	case tensor.Float32:
		profile.Start("Linear/F32")
		err := linearCUDAF32(ctx, input, weight, output)
		profile.End("Linear/F32")
		return err
	case tensor.Q8_K:
		profile.Start("Linear/Q8K")
		err := linearCUDAQ8K(ctx, input, weight, output)
		profile.End("Linear/Q8K")
		return err
	case tensor.Q5_K:
		profile.Start("Linear/Q5K")
		err := linearCUDAQ5K(ctx, input, weight, output)
		profile.End("Linear/Q5K")
		return err
	case tensor.Q4_K:
		profile.Start("Linear/Q4K")
		err := linearCUDAQ4K(ctx, input, weight, output)
		profile.End("Linear/Q4K")
		return err
	case tensor.Q2_K:
		profile.Start("Linear/Q2K")
		err := linearCUDAQ2K(ctx, input, weight, output)
		profile.End("Linear/Q2K")
		return err
	case tensor.Q3_K:
		profile.Start("Linear/Q3K")
		err := linearCUDAQ3K(ctx, input, weight, output)
		profile.End("Linear/Q3K")
		return err
	case tensor.Q6_K:
		profile.Start("Linear/Q6K")
		err := linearCUDAQ6K(ctx, input, weight, output)
		profile.End("Linear/Q6K")
		return err
	default:
		// No CUDA kernel for this dtype; fall back to the CPU path.
		profile.Start("Linear/CPU")
		err := linearCPU(input, weight, output)
		profile.End("Linear/CPU")
		return err
	}
}
func linearCPU(input, weight, output tensor.Tensor) error {
	inCPU, ok := input.(*cpu.Tensor)
	if !ok {
		var err error
		inCPU, err = ToCPU(input)
		if err != nil {
			return fmt.Errorf("linear: failed to get CPU input: %w", err)
		}
	}
	wCPU, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear: weight must be CPU tensor for CPU path")
	}
	outCPU, ok := output.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear: output must be CPU tensor for CPU path")
	}
	return matmul.Linear(inCPU, wCPU, outCPU)
}
// linearCUDAF32 handles Float32 weights. Weights are uploaded once as F16 and
// kept in the per-GPU cache; the F32 input is cast to F16 so the matmul runs
// on the half-precision kernel.
func linearCUDAF32(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse the input if it is already on the GPU, otherwise upload it.
	profile.Start("Linear/F32/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/F32/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		// The upload allocated a fresh GPU tensor; release it when done.
		defer gpuInput.Free()
	}

	// Fetch the cached F16 weight, uploading it on first use. The key combines
	// the layer index with the host tensor's pointer identity, which assumes
	// weight tensors stay alive for the lifetime of the process.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear F32: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_f16_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/F32/weight_upload")
		gpuWeight, err = cache.UploadF16(weightKey, cpuW)
		profile.End("Linear/F32/weight_upload")
		if err != nil {
			return fmt.Errorf("linear F32: cache weight: %w", err)
		}
	}

	// Allocate the F32 output on the GPU.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear F32: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)

	// The F16 kernel needs an F16 activation; cast unless the input already is.
	if gpuInput.DType() != tensor.Float16 {
		profile.Start("Linear/F32/cast_fp16")
		f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
		if err != nil {
			profile.End("Linear/F32/cast_fp16")
			return fmt.Errorf("linear F32: alloc f16 input: %w", err)
		}
		defer f16In.Free()
		if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil {
			profile.End("Linear/F32/cast_fp16")
			return fmt.Errorf("linear F32: cast input f32->f16: %w", err)
		}
		aPtr = f16In.Data().(unsafe.Pointer)
		profile.End("Linear/F32/cast_fp16")
	}

	// Profile only the kernel itself; the cast is accounted for above.
	profile.Start("Linear/F32/matmul_kernel")
	err = cuda.MatMulF16(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/F32/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear F32: matmul f16: %w", err)
	}

	// Copy the result back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear F32: copy D2H: %w", err)
		}
	}
	return nil
}
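
// The quantized variants below share one pattern: reuse or upload the F32
// input, fetch the quantized weight from the per-GPU cache (uploading it once
// on first use), run the fused dequantize+matmul kernel, and copy the F32
// result back to the host output. If these *_K dtypes follow the usual GGUF
// K-quant layout (an assumption; see pkg/tensor), each weight super-block
// covers 256 elements.
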
// linearCUDAQ5K - Q5_K weights with caching
func linearCUDAQ5K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse or upload the GPU input.
	profile.Start("Linear/Q5K/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q5K/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		defer gpuInput.Free()
	}

	// Get the cached weight, uploading it on first use.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear Q5K: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q5K/weight_upload")
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q5K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q5K: cache weight: %w", err)
		}
	}

	// Allocate the output.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q5K: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	// Execute the fused dequantize+matmul kernel.
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q5K/matmul_kernel")
	err = cuda.MatMulQ5K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q5K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q5K: matmul: %w", err)
	}

	// Copy back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q5K: copy D2H: %w", err)
		}
	}
	return nil
}
// linearCUDAQ8K - Q8_K weights with caching
func linearCUDAQ8K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse or upload the GPU input.
	profile.Start("Linear/Q8K/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q8K/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		defer gpuInput.Free()
	}

	// Get the cached weight, uploading it on first use.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear Q8K: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q8K/weight_upload")
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q8K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q8K: cache weight: %w", err)
		}
	}

	// Allocate the output.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q8K: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	// Execute the fused dequantize+matmul kernel.
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q8K/matmul_kernel")
	err = cuda.MatMulQ8K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q8K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q8K: matmul: %w", err)
	}

	// Copy back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q8K: copy D2H: %w", err)
		}
	}
	return nil
}
// linearCUDAQ4K - Q4_K weights with caching
func linearCUDAQ4K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse or upload the GPU input.
	profile.Start("Linear/Q4K/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q4K/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		defer gpuInput.Free()
	}

	// Get the cached weight, uploading it on first use.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear Q4K: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q4K/weight_upload")
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q4K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q4K: cache weight: %w", err)
		}
	}

	// Allocate the output.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q4K: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	// Execute the fused dequantize+matmul kernel.
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q4K/matmul_kernel")
	err = cuda.MatMulQ4K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q4K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q4K: matmul: %w", err)
	}

	// Copy back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q4K: copy D2H: %w", err)
		}
	}
	return nil
}
// linearCUDAQ2K - Q2_K weights with caching
func linearCUDAQ2K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse or upload the GPU input.
	profile.Start("Linear/Q2K/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q2K/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		defer gpuInput.Free()
	}

	// Get the cached weight, uploading it on first use.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear Q2K: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q2K/weight_upload")
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q2K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q2K: cache weight: %w", err)
		}
	}

	// Allocate the output.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q2K: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	// Execute the fused dequantize+matmul kernel.
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q2K/matmul_kernel")
	err = cuda.MatMulQ2K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q2K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q2K: matmul: %w", err)
	}

	// Copy back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q2K: copy D2H: %w", err)
		}
	}
	return nil
}
// linearCUDAQ3K - Q3_K weights with caching
func linearCUDAQ3K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse or upload the GPU input.
	profile.Start("Linear/Q3K/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q3K/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		defer gpuInput.Free()
	}

	// Get the cached weight, uploading it on first use.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear Q3K: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q3K/weight_upload")
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q3K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q3K: cache weight: %w", err)
		}
	}

	// Allocate the output.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q3K: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	// Execute the fused dequantize+matmul kernel.
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q3K/matmul_kernel")
	err = cuda.MatMulQ3K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q3K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q3K: matmul: %w", err)
	}

	// Copy back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q3K: copy D2H: %w", err)
		}
	}
	return nil
}
// linearCUDAQ6K - Q6_K weights with caching
func linearCUDAQ6K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU

	// Reuse or upload the GPU input.
	profile.Start("Linear/Q6K/input_upload")
	gpuInput, ownInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q6K/input_upload")
	if err != nil {
		return err
	}
	if ownInput {
		defer gpuInput.Free()
	}

	// Get the cached weight, uploading it on first use.
	cpuW, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear Q6K: weight must be a CPU tensor")
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, cpuW)
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q6K/weight_upload")
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q6K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q6K: cache weight: %w", err)
		}
	}

	// Allocate the output.
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q6K: alloc output: %w", err)
	}
	defer gpuOutput.Free()

	// Execute the fused dequantize+matmul kernel.
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q6K/matmul_kernel")
	err = cuda.MatMulQ6K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q6K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q6K: matmul: %w", err)
	}

	// Copy back when the caller provided a host output tensor.
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q6K: copy D2H: %w", err)
		}
	}
	return nil
}
// getOrUploadInput returns the input as a GPU tensor. An input that already
// lives on the GPU is returned as-is; a CPU input is copied into a freshly
// allocated F32 tensor. The bool reports whether this call allocated the
// returned tensor, so the caller knows it must free it.
func getOrUploadInput(input tensor.Tensor, gpu int) (*cuda.Tensor, bool, error) {
	if cudaIn, ok := input.(*cuda.Tensor); ok {
		return cudaIn, false, nil
	}
	cpuIn, ok := input.(*cpu.Tensor)
	if !ok {
		return nil, false, fmt.Errorf("input must be CPU or CUDA tensor")
	}
	// Guard: the H2D copy below reads the host data as F32.
	if cpuIn.DType() != tensor.Float32 {
		return nil, false, fmt.Errorf("input must be F32, got dtype %v", cpuIn.DType())
	}
	gpuInput, err := cuda.NewTensor(input.Shape(), tensor.Float32, gpu)
	if err != nil {
		return nil, false, fmt.Errorf("alloc GPU input: %w", err)
	}
	if err := gpuInput.CopyFrom(cpuIn.DataFloat32()); err != nil {
		gpuInput.Free()
		return nil, false, fmt.Errorf("copy input H2D: %w", err)
	}
	return gpuInput, true, nil
}
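
// warmWeightCache is a sketch, not used by the dispatcher above: it shows how
// a model loader could pre-populate the per-GPU weight cache at load time
// instead of paying the upload cost on the first forward pass. The key scheme
// matches the quantized paths above; the weights slice and its layer ordering
// are hypothetical.
func warmWeightCache(gpu int, weights []*cpu.Tensor) error {
	cache := GetWeightCache(gpu)
	for layer, w := range weights {
		key := fmt.Sprintf("layer%d_w_%p", layer, w)
		if _, ok := cache.Get(key); ok {
			continue // already resident on this GPU
		}
		if _, err := cache.Upload(key, w); err != nil {
			return fmt.Errorf("warm weight cache: layer %d: %w", layer, err)
		}
	}
	return nil
}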