| 123456789101112131415161718192021222324252627282930 |
- package graph
- import (
- "testing"
- "makarna/pkg/tensor"
- )
- func TestBuildPlanUsesLayerDevices(t *testing.T) {
- spec := RequestSpec{
- ID: "req",
- MaxContext: 8,
- BlockSize: 4,
- NumLayers: 3,
- UseAttention: true,
- LayerDevices: []tensor.DevicePlacement{
- {Type: tensor.CUDA, GPU: 0},
- {Type: tensor.CUDA, GPU: 1},
- {Type: tensor.CPU, GPU: -1},
- },
- }
- plan := BuildPlan(spec)
- if len(plan.Layers) != 3 {
- t.Fatalf("expected 3 layers, got %d", len(plan.Layers))
- }
- if plan.Layers[0].Device.GPU != 0 || plan.Layers[1].Device.GPU != 1 || plan.Layers[2].Device.Type != tensor.CPU {
- t.Fatalf("unexpected devices %+v", plan.Layers)
- }
- }
|