linear.go 982 B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. //go:build !cuda
  2. // Package compute provides device-agnostic computation dispatching.
  3. package compute
  4. import (
  5. "fmt"
  6. "makarna/pkg/backend/cpu"
  7. "makarna/pkg/backend/cpu/matmul"
  8. "makarna/pkg/tensor"
  9. )
  10. // Linear performs a linear layer: output = input @ weight.T
  11. // CPU-only build: always uses CPU path.
  12. func Linear(ctx *Context, input, weight, output tensor.Tensor) error {
  13. return linearCPU(input, weight, output)
  14. }
  15. // linearCPU executes matmul on CPU
  16. func linearCPU(input, weight, output tensor.Tensor) error {
  17. inCPU, ok := input.(*cpu.Tensor)
  18. if !ok {
  19. var err error
  20. inCPU, err = ToCPU(input)
  21. if err != nil {
  22. return fmt.Errorf("linear: failed to get CPU input: %w", err)
  23. }
  24. }
  25. wCPU, ok := weight.(*cpu.Tensor)
  26. if !ok {
  27. return fmt.Errorf("linear: weight must be CPU tensor for CPU path")
  28. }
  29. outCPU, ok := output.(*cpu.Tensor)
  30. if !ok {
  31. return fmt.Errorf("linear: output must be CPU tensor for CPU path")
  32. }
  33. return matmul.Linear(inCPU, wCPU, outCPU)
  34. }