// Tests for the engine decode micro-batcher. This file provides mock
// implementations of the model and KV-cache interfaces so the batching
// path can be exercised without a real model.
package engine

import (
	"context"
	"sync/atomic"
	"testing"
	"time"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/sample"
	"makarna/pkg/tensor"
)

// mockBatchModel is a minimal model stand-in that counts ForwardBatch
// invocations, letting tests assert that decode steps were micro-batched
// across sequences rather than run one forward call per sequence.
type mockBatchModel struct {
	cfg            *model.Config
	forwardBatches atomic.Int64 // number of ForwardBatch calls observed
}

// Forward returns zero-valued logits of shape (seq, vocab). The batching
// tests are not expected to route through this single-sequence path.
func (m *mockBatchModel) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kv model.KVCache) (tensor.Tensor, error) {
	seq := input.Shape()[0]
	return cpu.NewTensor(tensor.Shape{seq, m.cfg.VocabSize}, nil), nil
}

// ForwardBatch records the call and returns logits in which token 0
// dominates every row, so a greedy (temperature-0) sampler
// deterministically picks token 0 for each sequence.
func (m *mockBatchModel) ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []model.KVCache) (tensor.Tensor, error) {
	m.forwardBatches.Add(1)
	seq := input.Shape()[0]
	// Make token 0 always best so sampler deterministically picks 0.
	out := cpu.NewTensor(tensor.Shape{seq, m.cfg.VocabSize}, nil)
	for i := 0; i < seq; i++ {
		row := out.DataFloat32()[i*m.cfg.VocabSize : (i+1)*m.cfg.VocabSize]
		row[0] = 10
	}
	return out, nil
}

func (m *mockBatchModel) Config() *model.Config { return m.cfg }

func (m *mockBatchModel) Close() error { return nil }

func (m *mockBatchModel) SetTensor(string, tensor.Tensor) error { return nil }

// mockKV is an in-memory KV-cache stub that tracks only the committed
// sequence length; all tensor-returning methods report empty results.
type mockKV struct{ seqLen int }

func (k *mockKV) SeqLen() int { return k.seqLen }

func (k *mockKV) Commit(newTokens int) { k.seqLen += newTokens }

// Append stores nothing; it reports the current sequence length as the
// write offset, which is all the decode loop needs from this stub.
func (k *mockKV) Append(layer int, kt, vt tensor.Tensor) ([]kvcache.View, int, error) {
	return nil, k.seqLen, nil
}

func (k *mockKV) ContiguousKV(layer, kvLen, kvDim int) (tensor.Tensor, tensor.Tensor, bool, error) {
	return nil, nil, false, nil
}

func (k *mockKV) Views(layer int) []kvcache.View { return nil }

func (k *mockKV) IsOnGPU() bool { return false }

func (k *mockKV) LayerDevice(layer int) tensor.DevicePlacement {
	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
}

func (k *mockKV) MaxSeqLen() int { return 128 }

// Truncate clamps the requested length to [0, current length]; it only
// ever shrinks the recorded sequence length, never grows it.
func (k *mockKV) Truncate(seqLen int) {
	if seqLen < 0 {
		seqLen = 0
	}
	if seqLen < k.seqLen {
		k.seqLen = seqLen
	}
}

func (k *mockKV) Free() {}

// TestBatcher_MicroBatchesDecode verifies that two concurrently registered
// decode sequences are serviced by shared ForwardBatch calls.
func
TestBatcher_MicroBatchesDecode(t *testing.T) {
	// A tiny vocab suffices: ForwardBatch makes token 0 dominate, the
	// temperature-0 sampler therefore always emits 0, and EosID 7 is
	// never produced, so each sequence runs its full Remaining=3 steps.
	m := &mockBatchModel{cfg: &model.Config{VocabSize: 8, NumLayers: 1, NumKVHeads: 1, HeadDim: 2}}
	eng := &Engine{model: m}
	b := NewBatcher(eng)
	b.Start()
	// NOTE(review): the batcher is started but never stopped; if Start
	// spawns goroutines they outlive this test — confirm whether a
	// Stop/Shutdown exists and hook it into t.Cleanup.

	mkSeq := func(id string) *DecodeSequence {
		return &DecodeSequence{
			RequestID:      id,
			Ctx:            context.Background(),
			Cache:          &mockKV{},
			History:        []int{1, 2, 3},
			NextInputToken: 4,
			Remaining:      3,
			EosID:          7,
			Sampler:        sample.New(sample.Config{Temperature: 0}),
		}
	}

	ev1, err := b.RegisterDecode(mkSeq("r1"))
	if err != nil {
		t.Fatalf("register r1: %v", err)
	}
	ev2, err := b.RegisterDecode(mkSeq("r2"))
	if err != nil {
		t.Fatalf("register r2: %v", err)
	}

	// Drain until both are done.
	waitDone := func(ch <-chan DecodeEvent) {
		timeout := time.After(2 * time.Second)
		for {
			select {
			case ev, ok := <-ch:
				if !ok {
					return
				}
				if ev.Err != nil {
					t.Fatalf("event err: %v", ev.Err)
				}
				if ev.Done {
					return
				}
			case <-timeout:
				t.Fatalf("timeout")
			}
		}
	}
	waitDone(ev1)
	waitDone(ev2)

	calls := m.forwardBatches.Load()
	if calls <= 0 {
		t.Fatalf("expected ForwardBatch calls > 0")
	}
	// In ideal case, 3 steps total; but allow scheduling variance.
	// Still should be far less than per-seq calls.
	if calls > 6 {
		t.Fatalf("expected <=6 ForwardBatch calls, got %d", calls)
	}
}