linear_cuda_test.go 1018 B

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. //go:build cuda
  2. package matmul
  3. import (
  4. "testing"
  5. "makarna/pkg/backend/cpu"
  6. "makarna/pkg/tensor"
  7. )
  8. func TestLinearCudaMatchesCPU(t *testing.T) {
  9. // Small matrix to compare CPU vs CUDA paths.
  10. M, K, N := 4, 8, 3
  11. aCPU := cpu.NewTensor(tensor.Shape{M, K}, nil)
  12. wCPU := cpu.NewTensor(tensor.Shape{N, K}, nil)
  13. outCPU := cpu.NewTensor(tensor.Shape{M, N}, nil)
  14. outCUDA := cpu.NewTensor(tensor.Shape{M, N}, nil)
  15. fillSeq(aCPU.DataFloat32())
  16. fillSeq(wCPU.DataFloat32())
  17. if err := linearCPU(aCPU, wCPU, outCPU); err != nil {
  18. t.Fatalf("cpu linear failed: %v", err)
  19. }
  20. if err := Linear(aCPU, wCPU, outCUDA); err != nil {
  21. t.Fatalf("cuda linear failed: %v", err)
  22. }
  23. for i, v := range outCPU.DataFloat32() {
  24. got := outCUDA.DataFloat32()[i]
  25. if diff := abs32(v - got); diff > 1e-4 {
  26. t.Fatalf("mismatch at %d: cpu=%f cuda=%f", i, v, got)
  27. }
  28. }
  29. }
  30. func fillSeq(dst []float32) {
  31. for i := range dst {
  32. dst[i] = float32(i%7 + 1)
  33. }
  34. }
  35. func abs32(v float32) float32 {
  36. if v < 0 {
  37. return -v
  38. }
  39. return v
  40. }