batcher_test.go

package engine

import (
	"context"
	"sync/atomic"
	"testing"
	"time"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/sample"
	"makarna/pkg/tensor"
)
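
// mockBatchModel is a stub model that counts ForwardBatch invocations, so the
// test can verify that decode steps were batched across sequences rather than
// issued as one model call per sequence.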
type mockBatchModel struct {
	cfg            *model.Config
	forwardBatches atomic.Int64
}

func (m *mockBatchModel) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kv model.KVCache) (tensor.Tensor, error) {
	seq := input.Shape()[0]
	return cpu.NewTensor(tensor.Shape{seq, m.cfg.VocabSize}, nil), nil
}

func (m *mockBatchModel) ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []model.KVCache) (tensor.Tensor, error) {
	m.forwardBatches.Add(1)
	seq := input.Shape()[0]
	// Make token 0 always best so the sampler deterministically picks 0.
	out := cpu.NewTensor(tensor.Shape{seq, m.cfg.VocabSize}, nil)
	for i := 0; i < seq; i++ {
		row := out.DataFloat32()[i*m.cfg.VocabSize : (i+1)*m.cfg.VocabSize]
		row[0] = 10
	}
	return out, nil
}

func (m *mockBatchModel) Config() *model.Config { return m.cfg }

func (m *mockBatchModel) Close() error { return nil }

func (m *mockBatchModel) SetTensor(string, tensor.Tensor) error { return nil }
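
// mockKV is a minimal KVCache stub: it tracks only the committed sequence
// length and returns empty placeholders for every tensor-producing method.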
type mockKV struct{ seqLen int }

func (k *mockKV) SeqLen() int { return k.seqLen }

func (k *mockKV) Commit(newTokens int) { k.seqLen += newTokens }

func (k *mockKV) Append(layer int, kt, vt tensor.Tensor) ([]kvcache.View, int, error) {
	return nil, k.seqLen, nil
}

func (k *mockKV) ContiguousKV(layer, kvLen, kvDim int) (tensor.Tensor, tensor.Tensor, bool, error) {
	return nil, nil, false, nil
}

func (k *mockKV) Views(layer int) []kvcache.View { return nil }

func (k *mockKV) IsOnGPU() bool { return false }

func (k *mockKV) LayerDevice(layer int) tensor.DevicePlacement {
	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
}

func (k *mockKV) MaxSeqLen() int { return 128 }

func (k *mockKV) Truncate(seqLen int) {
	if seqLen < 0 {
		seqLen = 0
	}
	if seqLen < k.seqLen {
		k.seqLen = seqLen
	}
}

func (k *mockKV) Free() {}
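
// TestBatcher_MicroBatchesDecode registers two decode sequences and expects
// the batcher to serve their decode steps through shared ForwardBatch calls.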
func TestBatcher_MicroBatchesDecode(t *testing.T) {
	m := &mockBatchModel{cfg: &model.Config{VocabSize: 8, NumLayers: 1, NumKVHeads: 1, HeadDim: 2}}
	eng := &Engine{model: m}
	b := NewBatcher(eng)
	b.Start()

	// Two identical sequences, each with 3 tokens left to decode.
	mkSeq := func(id string) *DecodeSequence {
		return &DecodeSequence{
			RequestID:      id,
			Ctx:            context.Background(),
			Cache:          &mockKV{},
			History:        []int{1, 2, 3},
			NextInputToken: 4,
			Remaining:      3,
			EosID:          7,
			Sampler:        sample.New(sample.Config{Temperature: 0}),
		}
	}

	ev1, err := b.RegisterDecode(mkSeq("r1"))
	if err != nil {
		t.Fatalf("register r1: %v", err)
	}
	ev2, err := b.RegisterDecode(mkSeq("r2"))
	if err != nil {
		t.Fatalf("register r2: %v", err)
	}

	// Drain each event channel until its sequence reports Done.
	waitDone := func(ch <-chan DecodeEvent) {
		timeout := time.After(2 * time.Second)
		for {
			select {
			case ev, ok := <-ch:
				if !ok {
					return
				}
				if ev.Err != nil {
					t.Fatalf("event err: %v", ev.Err)
				}
				if ev.Done {
					return
				}
			case <-timeout:
				t.Fatalf("timeout waiting for decode events")
			}
		}
	}
	waitDone(ev1)
	waitDone(ev2)

	calls := m.forwardBatches.Load()
	if calls <= 0 {
		t.Fatalf("expected ForwardBatch calls > 0")
	}
	// Ideally both sequences share every decode step, for 3 ForwardBatch calls
	// in total. Allow for scheduling variance, but the count must not exceed
	// the 6 calls a purely per-sequence engine would make.
	if calls > 6 {
		t.Fatalf("expected <=6 ForwardBatch calls, got %d", calls)
	}
}