| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- //go:build cuda
- package matmul
- import (
- "fmt"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cuda"
- "makarna/pkg/tensor"
- )
- // Linear offloads float32 matmul to CUDA when built with the cuda tag.
- // For non-float32 weights, it falls back to the CPU path.
- func Linear(input, weight, output *cpu.Tensor) error {
- // Fall back for non-float32 weights or inputs.
- if weight.DType() != tensor.Float32 || input.DType() != tensor.Float32 || output.DType() != tensor.Float32 {
- return linearCPU(input, weight, output)
- }
- inShape := input.Shape()
- wShape := weight.Shape()
- if len(inShape) != 2 || len(wShape) != 2 {
- return fmt.Errorf("linear: expected 2D inputs, got input %v, weight %v", inShape, wShape)
- }
- M := inShape[0]
- K := inShape[1]
- N := wShape[0]
- if wShape[1] != K {
- return fmt.Errorf("linear: shape mismatch: input [*, %d] vs weight [%d, %d]", K, N, wShape[1])
- }
- // Allocate CUDA buffers
- a, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float32, 0)
- if err != nil {
- return err
- }
- // Weight stays row-major [N, K]
- b, err := cuda.NewTensor(tensor.Shape{N, K}, tensor.Float32, 0)
- if err != nil {
- return err
- }
- c, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, 0)
- if err != nil {
- return err
- }
- // Copy input
- if err := a.CopyFrom(input.DataFloat32()); err != nil {
- return fmt.Errorf("linear: copy A failed: %w", err)
- }
- // Copy weight as-is (row-major [N, K]); CUDA kernel handles NT
- if err := b.CopyFrom(weight.DataFloat32()); err != nil {
- return fmt.Errorf("linear: copy B failed: %w", err)
- }
- // MatMul: c = a @ b
- if err := a.MatMul(b, c); err != nil {
- return fmt.Errorf("linear: cuda matmul failed: %w", err)
- }
- // Copy back to CPU output
- hostC := make([]float32, M*N)
- if err := c.CopyToHost(hostC); err != nil {
- return fmt.Errorf("linear: copy C failed: %w", err)
- }
- copy(output.DataFloat32(), hostC)
- return nil
- }
|