//go:build cuda package matmul import ( "fmt" "makarna/pkg/backend/cpu" "makarna/pkg/backend/cuda" "makarna/pkg/tensor" ) // Linear offloads float32 matmul to CUDA when built with the cuda tag. // For non-float32 weights, it falls back to the CPU path. func Linear(input, weight, output *cpu.Tensor) error { // Fall back for non-float32 weights or inputs. if weight.DType() != tensor.Float32 || input.DType() != tensor.Float32 || output.DType() != tensor.Float32 { return linearCPU(input, weight, output) } inShape := input.Shape() wShape := weight.Shape() if len(inShape) != 2 || len(wShape) != 2 { return fmt.Errorf("linear: expected 2D inputs, got input %v, weight %v", inShape, wShape) } M := inShape[0] K := inShape[1] N := wShape[0] if wShape[1] != K { return fmt.Errorf("linear: shape mismatch: input [*, %d] vs weight [%d, %d]", K, N, wShape[1]) } // Allocate CUDA buffers a, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float32, 0) if err != nil { return err } // Weight stays row-major [N, K] b, err := cuda.NewTensor(tensor.Shape{N, K}, tensor.Float32, 0) if err != nil { return err } c, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, 0) if err != nil { return err } // Copy input if err := a.CopyFrom(input.DataFloat32()); err != nil { return fmt.Errorf("linear: copy A failed: %w", err) } // Copy weight as-is (row-major [N, K]); CUDA kernel handles NT if err := b.CopyFrom(weight.DataFloat32()); err != nil { return fmt.Errorf("linear: copy B failed: %w", err) } // MatMul: c = a @ b if err := a.MatMul(b, c); err != nil { return fmt.Errorf("linear: cuda matmul failed: %w", err) } // Copy back to CPU output hostC := make([]float32, M*N) if err := c.CopyToHost(hostC); err != nil { return fmt.Errorf("linear: copy C failed: %w", err) } copy(output.DataFloat32(), hostC) return nil }