graph_test.go 667 B

123456789101112131415161718192021222324252627282930
  1. package graph
  2. import (
  3. "testing"
  4. "makarna/pkg/tensor"
  5. )
  6. func TestBuildPlanUsesLayerDevices(t *testing.T) {
  7. spec := RequestSpec{
  8. ID: "req",
  9. MaxContext: 8,
  10. BlockSize: 4,
  11. NumLayers: 3,
  12. UseAttention: true,
  13. LayerDevices: []tensor.DevicePlacement{
  14. {Type: tensor.CUDA, GPU: 0},
  15. {Type: tensor.CUDA, GPU: 1},
  16. {Type: tensor.CPU, GPU: -1},
  17. },
  18. }
  19. plan := BuildPlan(spec)
  20. if len(plan.Layers) != 3 {
  21. t.Fatalf("expected 3 layers, got %d", len(plan.Layers))
  22. }
  23. if plan.Layers[0].Device.GPU != 0 || plan.Layers[1].Device.GPU != 1 || plan.Layers[2].Device.Type != tensor.CPU {
  24. t.Fatalf("unexpected devices %+v", plan.Layers)
  25. }
  26. }