// scheduler_test.go verifies Scheduler position offsets and context accounting.
package engine

import (
	"context"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/graph"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
  12. func TestSchedulerOffsetsPositionsAndContext(t *testing.T) {
  13. cfg := &model.Config{
  14. VocabSize: 5,
  15. NumLayers: 1,
  16. NumKVHeads: 1,
  17. HeadDim: 2,
  18. }
  19. mock := &mockModel{cfg: cfg}
  20. engine := &Engine{model: mock}
  21. plan := graph.BuildPlan(graph.RequestSpec{
  22. ID: "req-1",
  23. MaxContext: 8,
  24. BlockSize: 4,
  25. NumLayers: 1,
  26. UseAttention: true,
  27. LayerDevices: []tensor.DevicePlacement{{Type: tensor.CUDA, GPU: 0}},
  28. })
  29. pool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
  30. NumLayers: 1,
  31. NumKVHeads: 1,
  32. HeadDim: 2,
  33. BlockSize: 4,
  34. NumBlocks: 8,
  35. Device: tensor.CPU,
  36. GPU: 0,
  37. })
  38. if err != nil {
  39. t.Fatalf("NewBlockPool: %v", err)
  40. }
  41. cache := kvcache.NewPagedKVCache(pool, kvcache.PagedCacheConfig{
  42. NumLayers: 1,
  43. NumKVHeads: 1,
  44. HeadDim: 2,
  45. BlockSize: 4,
  46. MaxSeqLen: 8,
  47. Device: tensor.CPU,
  48. GPU: 0,
  49. }, "sched")
  50. defer cache.Free()
  51. sched := engine.NewScheduler(plan, cache)
  52. if _, err := sched.Prefill([]int{1, 2, 3}); err != nil {
  53. t.Fatalf("prefill failed: %v", err)
  54. }
  55. if _, err := sched.Decode([]int{4, 5}); err != nil {
  56. t.Fatalf("decode failed: %v", err)
  57. }
  58. if cache.SeqLen() != 5 {
  59. t.Fatalf("expected seqLen 5, got %d", cache.SeqLen())
  60. }
  61. if remaining := sched.RemainingContext(); remaining != 3 {
  62. t.Fatalf("expected remaining context 3, got %d", remaining)
  63. }
  64. if len(mock.positions) != 2 {
  65. t.Fatalf("expected 2 forward calls, got %d", len(mock.positions))
  66. }
  67. expectSlice(t, []int{0, 1, 2}, mock.positions[0])
  68. expectSlice(t, []int{3, 4}, mock.positions[1])
  69. if plan.Layers[0].Device.Type != tensor.CUDA || plan.Layers[0].Device.GPU != 0 {
  70. t.Fatalf("expected layer device cuda:0, got %+v", plan.Layers[0].Device)
  71. }
  72. }
  73. type mockModel struct {
  74. cfg *model.Config
  75. positions [][]int
  76. }
  77. func (m *mockModel) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kv model.KVCache) (tensor.Tensor, error) {
  78. seq := input.Shape()[0]
  79. pos := nn.ParsePositions(positions, seq)
  80. m.positions = append(m.positions, pos)
  81. return cpu.NewTensor(tensor.Shape{seq, m.cfg.VocabSize}, nil), nil
  82. }
  83. func (m *mockModel) Config() *model.Config { return m.cfg }
  84. func (m *mockModel) Close() error { return nil }
  85. func (m *mockModel) SetTensor(string, tensor.Tensor) error {
  86. return nil
  87. }
  88. func expectSlice(t *testing.T, want, got []int) {
  89. if len(want) != len(got) {
  90. t.Fatalf("length mismatch want=%v got=%v", want, got)
  91. }
  92. for i := range want {
  93. if want[i] != got[i] {
  94. t.Fatalf("slice mismatch at %d: want=%d got=%d", i, want[i], got[i])
  95. }
  96. }
  97. }