package engine

import (
	"context"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/graph"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)

// TestSchedulerOffsetsPositionsAndContext drives a Scheduler through a
// 3-token prefill followed by a 2-token decode, using a mock model and a
// paged KV cache with an 8-token context, and checks that:
//   - the cache sequence length ends at 5 and RemainingContext reports 3,
//   - the model saw positions [0 1 2] for prefill and [3 4] for decode,
//   - the layer placement requested in the plan (cuda:0) is preserved.
func TestSchedulerOffsetsPositionsAndContext(t *testing.T) {
	cfg := &model.Config{
		VocabSize:  5,
		NumLayers:  1,
		NumKVHeads: 1,
		HeadDim:    2,
	}
	mock := &mockModel{cfg: cfg}
	eng := &Engine{model: mock}

	plan := graph.BuildPlan(graph.RequestSpec{
		ID:           "req-1",
		MaxContext:   8,
		BlockSize:    4,
		NumLayers:    1,
		UseAttention: true,
		LayerDevices: []tensor.DevicePlacement{{Type: tensor.CUDA, GPU: 0}},
	})

	// The KV pool and cache below are created on tensor.CPU; only the plan's
	// layer placement names CUDA, and the final assertion checks that this
	// placement survives scheduling.
	pool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
		NumLayers:  1,
		NumKVHeads: 1,
		HeadDim:    2,
		BlockSize:  4,
		NumBlocks:  8,
		Device:     tensor.CPU,
		GPU:        0,
	})
	if err != nil {
		t.Fatalf("NewBlockPool: %v", err)
	}

	cache := kvcache.NewPagedKVCache(pool, kvcache.PagedCacheConfig{
		NumLayers:  1,
		NumKVHeads: 1,
		HeadDim:    2,
		BlockSize:  4,
		MaxSeqLen:  8,
		Device:     tensor.CPU,
		GPU:        0,
	}, "sched")
	defer cache.Free()

	sched := eng.NewScheduler(plan, cache)

	if _, err := sched.Prefill([]int{1, 2, 3}); err != nil {
		t.Fatalf("prefill failed: %v", err)
	}
	if _, err := sched.Decode([]int{4, 5}); err != nil {
		t.Fatalf("decode failed: %v", err)
	}

	if cache.SeqLen() != 5 {
		t.Fatalf("expected seqLen 5, got %d", cache.SeqLen())
	}
	if remaining := sched.RemainingContext(); remaining != 3 {
		t.Fatalf("expected remaining context 3, got %d", remaining)
	}

	if len(mock.positions) != 2 {
		t.Fatalf("expected 2 forward calls, got %d", len(mock.positions))
	}
	// Decode positions must continue from where prefill stopped.
	expectSlice(t, []int{0, 1, 2}, mock.positions[0])
	expectSlice(t, []int{3, 4}, mock.positions[1])

	if plan.Layers[0].Device.Type != tensor.CUDA || plan.Layers[0].Device.GPU != 0 {
		t.Fatalf("expected layer device cuda:0, got %+v", plan.Layers[0].Device)
	}
}

// mockModel is a test double for the engine's model. It records the
// positions passed to every Forward call and returns a logits tensor of
// shape [seq, VocabSize] built with nil data.
type mockModel struct {
	cfg       *model.Config
	positions [][]int // one entry per Forward call, in call order
}

// Forward parses the positions tensor for the seq = input.Shape()[0] tokens of
// this call, records them, and returns a [seq, VocabSize] tensor. ctx and kv
// are not used.
func (m *mockModel) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kv model.KVCache) (tensor.Tensor, error) {
	seq := input.Shape()[0]
	pos := nn.ParsePositions(positions, seq)
	// Store a private copy so the recorded values cannot change if
	// ParsePositions returns a view over a buffer the caller later reuses.
	m.positions = append(m.positions, append([]int(nil), pos...))
	return cpu.NewTensor(tensor.Shape{seq, m.cfg.VocabSize}, nil), nil
}

// Config returns the configuration the mock was built with.
func (m *mockModel) Config() *model.Config { return m.cfg }

// Close is a no-op.
func (m *mockModel) Close() error { return nil }

// SetTensor ignores its arguments and always succeeds.
func (m *mockModel) SetTensor(string, tensor.Tensor) error { return nil }

// expectSlice fails the test immediately if got differs from want in length
// or in any element. It is marked as a helper so failures are reported at the
// caller's line.
func expectSlice(t *testing.T, want, got []int) {
	t.Helper()
	if len(want) != len(got) {
		t.Fatalf("length mismatch want=%v got=%v", want, got)
	}
	for i := range want {
		if want[i] != got[i] {
			t.Fatalf("slice mismatch at %d: want=%d got=%d", i, want[i], got[i])
		}
	}
}