1
0

linear_cuda.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. //go:build cuda
  2. package matmul
  3. import (
  4. "fmt"
  5. "makarna/pkg/backend/cpu"
  6. "makarna/pkg/backend/cuda"
  7. "makarna/pkg/tensor"
  8. )
  9. // Linear offloads float32 matmul to CUDA when built with the cuda tag.
  10. // For non-float32 weights, it falls back to the CPU path.
  11. func Linear(input, weight, output *cpu.Tensor) error {
  12. // Fall back for non-float32 weights or inputs.
  13. if weight.DType() != tensor.Float32 || input.DType() != tensor.Float32 || output.DType() != tensor.Float32 {
  14. return linearCPU(input, weight, output)
  15. }
  16. inShape := input.Shape()
  17. wShape := weight.Shape()
  18. if len(inShape) != 2 || len(wShape) != 2 {
  19. return fmt.Errorf("linear: expected 2D inputs, got input %v, weight %v", inShape, wShape)
  20. }
  21. M := inShape[0]
  22. K := inShape[1]
  23. N := wShape[0]
  24. if wShape[1] != K {
  25. return fmt.Errorf("linear: shape mismatch: input [*, %d] vs weight [%d, %d]", K, N, wShape[1])
  26. }
  27. // Allocate CUDA buffers
  28. a, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float32, 0)
  29. if err != nil {
  30. return err
  31. }
  32. // Weight stays row-major [N, K]
  33. b, err := cuda.NewTensor(tensor.Shape{N, K}, tensor.Float32, 0)
  34. if err != nil {
  35. return err
  36. }
  37. c, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, 0)
  38. if err != nil {
  39. return err
  40. }
  41. // Copy input
  42. if err := a.CopyFrom(input.DataFloat32()); err != nil {
  43. return fmt.Errorf("linear: copy A failed: %w", err)
  44. }
  45. // Copy weight as-is (row-major [N, K]); CUDA kernel handles NT
  46. if err := b.CopyFrom(weight.DataFloat32()); err != nil {
  47. return fmt.Errorf("linear: copy B failed: %w", err)
  48. }
  49. // MatMul: c = a @ b
  50. if err := a.MatMul(b, c); err != nil {
  51. return fmt.Errorf("linear: cuda matmul failed: %w", err)
  52. }
  53. // Copy back to CPU output
  54. hostC := make([]float32, M*N)
  55. if err := c.CopyToHost(hostC); err != nil {
  56. return fmt.Errorf("linear: copy C failed: %w", err)
  57. }
  58. copy(output.DataFloat32(), hostC)
  59. return nil
  60. }