| 1234567891011121314151617181920212223242526 |
- package main
- import (
- "testing"
- "makarna/pkg/tensor"
- )
- func TestParseLayerMap(t *testing.T) {
- placements := parseLayerMap(5, "0-1:gpu0,2-3:gpu1,4:cpu")
- if len(placements) != 5 {
- t.Fatalf("expected 5 placements, got %d", len(placements))
- }
- expectPlacement(t, placements[0], tensor.CUDA, 0)
- expectPlacement(t, placements[1], tensor.CUDA, 0)
- expectPlacement(t, placements[2], tensor.CUDA, 1)
- expectPlacement(t, placements[3], tensor.CUDA, 1)
- expectPlacement(t, placements[4], tensor.CPU, -1)
- }
- func expectPlacement(t *testing.T, p tensor.DevicePlacement, device tensor.DeviceType, gpu int) {
- if p.Type != device || p.GPU != gpu {
- t.Fatalf("expected device %v gpu %d, got %v gpu %d", device, gpu, p.Type, p.GPU)
- }
- }
|