main_test.go 733 B

1234567891011121314151617181920212223242526
  1. package main
  2. import (
  3. "testing"
  4. "makarna/pkg/tensor"
  5. )
  // TestParseLayerMap checks that the spec "0-1:gpu0,2-3:gpu1,4:cpu" expands
  // into one placement per layer: layers 0-1 on CUDA GPU 0, layers 2-3 on
  // CUDA GPU 1, and layer 4 on the CPU.
  //
  // Each layer is verified in its own subtest, so a failure names the layer
  // and does not hide mismatches on the remaining layers.
  func TestParseLayerMap(t *testing.T) {
  	want := []struct {
  		name   string
  		device tensor.DeviceType
  		gpu    int
  	}{
  		{"layer0", tensor.CUDA, 0},
  		{"layer1", tensor.CUDA, 0},
  		{"layer2", tensor.CUDA, 1},
  		{"layer3", tensor.CUDA, 1},
  		// NOTE(review): -1 appears to be the "no GPU" ordinal for CPU placements — confirm.
  		{"layer4", tensor.CPU, -1},
  	}

  	// The layer count comes from the table so the two cannot drift apart.
  	placements := parseLayerMap(len(want), "0-1:gpu0,2-3:gpu1,4:cpu")
  	if len(placements) != len(want) {
  		t.Fatalf("expected %d placements, got %d", len(want), len(placements))
  	}

  	for i, w := range want {
  		// t.Run is synchronous here (no t.Parallel), so capturing i and w is safe.
  		t.Run(w.name, func(t *testing.T) {
  			expectPlacement(t, placements[i], w.device, w.gpu)
  		})
  	}
  }
  // expectPlacement fails the test immediately (via t.Fatalf) unless p has
  // device type `device` and GPU ordinal `gpu`.
  //
  // It calls t.Helper so a failure is reported at the caller's line rather
  // than inside this helper.
  func expectPlacement(t *testing.T, p tensor.DevicePlacement, device tensor.DeviceType, gpu int) {
  	t.Helper()
  	if p.Type != device || p.GPU != gpu {
  		t.Fatalf("expected device %v gpu %d, got %v gpu %d", device, gpu, p.Type, p.GPU)
  	}
  }