package engine

import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/compute"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/sample"
	"makarna/pkg/tensor"
)

// DecodeEvent is one decode-step result delivered on a sequence's event channel.
// Either Token carries a newly sampled token, or Done signals completion
// (with Err set if the sequence ended abnormally).
type DecodeEvent struct {
	Token int
	Done  bool
	Err   error
}

// DecodeSequence describes a single in-flight generation request managed by the Batcher.
type DecodeSequence struct {
	RequestID string
	Ctx       context.Context
	Cache     kvcache.KVCacheInterface
	// History includes prompt + generated tokens so far.
	History        []int
	NextInputToken int
	Remaining      int
	EosID          int
	Sampler        *sample.Sampler
}

// Batcher multiplexes concurrent decode sequences into batched forward passes
// on a single Engine. All scheduling runs in one goroutine (see loop).
type Batcher struct {
	eng       *Engine
	cmdCh     chan any
	onceStart sync.Once
}

type registerCmd struct {
	seq   *DecodeSequence
	event chan DecodeEvent
	resp  chan error
}

type stopCmd struct {
	reqID string
}

type seqState struct {
	seq   *DecodeSequence
	event chan DecodeEvent
}

// NewBatcher creates a Batcher bound to eng. The decode loop is started
// lazily on the first RegisterDecode or Stop call.
func NewBatcher(eng *Engine) *Batcher {
	return &Batcher{
		eng:   eng,
		cmdCh: make(chan any, 1024),
	}
}

// Start launches the scheduling loop exactly once.
func (b *Batcher) Start() {
	b.onceStart.Do(func() {
		go b.loop()
	})
}

// RegisterDecode queues seq for batched decoding and returns the channel on
// which its tokens and completion event will be delivered.
func (b *Batcher) RegisterDecode(seq *DecodeSequence) (<-chan DecodeEvent, error) {
	if seq == nil {
		return nil, fmt.Errorf("nil sequence")
	}
	if seq.Cache == nil {
		return nil, fmt.Errorf("nil cache")
	}
	if seq.Sampler == nil {
		return nil, fmt.Errorf("nil sampler")
	}
	if seq.Ctx == nil {
		seq.Ctx = context.Background()
	}
	if seq.Remaining <= 0 {
		ch := make(chan DecodeEvent)
		close(ch)
		return ch, nil
	}
	b.Start()
	event := make(chan DecodeEvent, 16)
	resp := make(chan error, 1)
	b.cmdCh <- registerCmd{seq: seq, event: event, resp: resp}
	return event, <-resp
}

// Stop cancels the sequence registered under reqID, if any.
func (b *Batcher) Stop(reqID string) {
	if reqID == "" {
		return
	}
	b.Start()
	b.cmdCh <- stopCmd{reqID: reqID}
}

// loop is the single scheduler goroutine: it drains control commands, batches
// every ready sequence into one forward pass, then samples one token per
// sequence per iteration.
func (b *Batcher) loop() {
	seqs := make(map[string]*seqState)
	// Scratch set reused for all batch steps (supports multi-GPU).
	var scratchSet *compute.ScratchSet
	var baseScratch *compute.ScratchSpace
	if b.eng != nil && b.eng.Dispatcher() != nil && cuda.Available() {
		cfg := b.eng.Model().Config()
		gpus := collectDispatcherGPUs(b.eng.Dispatcher(), cfg.NumLayers)
		if len(gpus) > 0 {
			if ss, err := compute.NewScratchSet(gpus, compute.DefaultScratchBytes); err == nil {
				scratchSet = ss
				defer scratchSet.Free()
				baseScratch = scratchSet.Scratch(gpus[0])
			}
		}
	}
	for {
		// Block when idle.
		if len(seqs) == 0 {
			cmd := <-b.cmdCh
			switch c := cmd.(type) {
			case registerCmd:
				if c.seq.RequestID == "" {
					c.resp <- fmt.Errorf("missing RequestID")
					close(c.event)
					continue
				}
				seqs[c.seq.RequestID] = &seqState{seq: c.seq, event: c.event}
				c.resp <- nil
			case stopCmd:
				if st, ok := seqs[c.reqID]; ok {
					st.event <- DecodeEvent{Done: true, Err: context.Canceled}
					close(st.event)
					delete(seqs, c.reqID)
				}
			}
			continue
		}
		// Drain control commands without blocking.
	drain:
		for {
			select {
			case cmd := <-b.cmdCh:
				switch c := cmd.(type) {
				case registerCmd:
					if c.seq.RequestID == "" {
						c.resp <- fmt.Errorf("missing RequestID")
						close(c.event)
						continue
					}
					seqs[c.seq.RequestID] = &seqState{seq: c.seq, event: c.event}
					c.resp <- nil
				case stopCmd:
					if st, ok := seqs[c.reqID]; ok {
						st.event <- DecodeEvent{Done: true, Err: context.Canceled}
						close(st.event)
						delete(seqs, c.reqID)
					}
				}
			default:
				break drain
			}
		}

		// Collect ready sequences; finished or cancelled ones are completed and removed here.
		batch := make([]*seqState, 0, len(seqs))
		for id, st := range seqs {
			if st == nil || st.seq == nil {
				delete(seqs, id)
				continue
			}
			seq := st.seq
			if seq.Ctx != nil {
				select {
				case <-seq.Ctx.Done():
					st.event <- DecodeEvent{Done: true, Err: seq.Ctx.Err()}
					close(st.event)
					delete(seqs, id)
					continue
				default:
				}
			}
			if seq.Remaining <= 0 {
				st.event <- DecodeEvent{Done: true}
				close(st.event)
				delete(seqs, id)
				continue
			}
			if seq.NextInputToken == seq.EosID {
				st.event <- DecodeEvent{Done: true}
				close(st.event)
				delete(seqs, id)
				continue
			}
			batch = append(batch, st)
		}
		if len(batch) == 0 {
			time.Sleep(50 * time.Microsecond)
			continue
		}

		// Build input tokens and positions, one row per sequence.
		input := cpu.NewTensor(tensor.Shape{len(batch)}, nil)
		pos := cpu.NewTensor(tensor.Shape{len(batch)}, nil)
		kvCaches := make([]model.KVCache, len(batch))
		for i, st := range batch {
			seq := st.seq
			input.DataFloat32()[i] = float32(seq.NextInputToken)
			pos.DataFloat32()[i] = float32(seq.Cache.SeqLen())
			kvCaches[i] = seq.Cache
		}
		ctx := context.Background()
		if scratchSet != nil {
			scratchSet.Reset()
			ctx = compute.WithScratchSet(ctx, scratchSet)
		}
		if baseScratch != nil {
			baseScratch.Reset()
			ctx = compute.WithScratch(ctx, baseScratch)
		}
		logits, err := b.eng.ForwardBatch(ctx, input, pos, kvCaches)
		if err != nil {
			// A failed forward pass terminates every sequence in the batch.
			for _, st := range batch {
				st.event <- DecodeEvent{Done: true, Err: err}
				close(st.event)
				delete(seqs, st.seq.RequestID)
			}
			continue
		}
		vocab := logits.Shape()[1]
		for i, st := range batch {
			seq := st.seq
			// Limit the history window handed to the sampler.
			recent := seq.History
			if len(recent) > 64 {
				recent = recent[len(recent)-64:]
			}
			next, rowErr := sampleNextTokenFromLogits(logits, i, vocab, seq.Sampler, recent)
			if rowErr != nil {
				st.event <- DecodeEvent{Done: true, Err: rowErr}
				close(st.event)
				delete(seqs, seq.RequestID)
				continue
			}
			seq.History = append(seq.History, next)
			if pc, ok := seq.Cache.(*kvcache.PagedKVCache); ok {
				pc.AppendToken(next)
			}
			seq.NextInputToken = next
			seq.Remaining--
			st.event <- DecodeEvent{Token: next}
			if next == seq.EosID || seq.Remaining <= 0 {
				st.event <- DecodeEvent{Done: true}
				close(st.event)
				delete(seqs, seq.RequestID)
			}
		}
	}
}

// collectDispatcherGPUs returns the sorted, deduplicated CUDA GPU indices used
// by the dispatcher's layer placements.
func collectDispatcherGPUs(d *device.DeviceDispatcher, numLayers int) []int {
	if d == nil || numLayers <= 0 {
		return nil
	}
	seen := make(map[int]struct{})
	out := make([]int, 0, 4)
	for i := 0; i < numLayers; i++ {
		p := d.LayerPlacement(i).Normalize()
		if p.Type != tensor.CUDA || p.GPU < 0 {
			continue
		}
		if _, ok := seen[p.GPU]; ok {
			continue
		}
		seen[p.GPU] = struct{}{}
		out = append(out, p.GPU)
	}
	sort.Ints(out)
	return out
}

// logitsRowCPU copies one row of a [batch, vocab] logits tensor to host memory.
func logitsRowCPU(logits tensor.Tensor, row int, vocab int) ([]float32, error) {
	if logits == nil {
		return nil, fmt.Errorf("nil logits")
	}
	shape := logits.Shape()
	if len(shape) != 2 {
		return nil, fmt.Errorf("expected 2D logits, got shape %v", shape)
	}
	if row < 0 || row >= shape[0] {
		return nil, fmt.Errorf("row %d out of range", row)
	}
	if vocab <= 0 || vocab != shape[1] {
		return nil, fmt.Errorf("vocab mismatch: %d vs %d", vocab, shape[1])
	}
	if cpuT, ok := logits.(*cpu.Tensor); ok {
		start := row * vocab
		end := start + vocab
		data := cpuT.DataFloat32()
		if end > len(data) {
			return nil, fmt.Errorf("cpu logits out of range")
		}
fmt.Errorf("cpu logits out of range") } out := make([]float32, vocab) copy(out, data[start:end]) return out, nil } if cudaT, ok := logits.(*cuda.Tensor); ok { view, err := cudaT.ViewAt(tensor.Shape{vocab}, uintptr(row*vocab*4)) if err != nil { return nil, err } host := make([]float32, vocab) if err := view.CopyToHost(host); err != nil { return nil, err } return host, nil } if p, ok := logits.Data().(unsafe.Pointer); ok && p != nil { _ = p } return nil, fmt.Errorf("unsupported logits tensor type %T", logits) } func sampleNextTokenFromLogits(logits tensor.Tensor, row int, vocab int, sampler *sample.Sampler, recent []int) (int, error) { if logits == nil { return 0, fmt.Errorf("nil logits") } if sampler == nil { return 0, fmt.Errorf("nil sampler") } if row < 0 { return 0, fmt.Errorf("row %d out of range", row) } cfg := sampler.Config() k := cfg.TopK if cfg.Temperature == 0 { k = 1 } // CPU logits: zero-copy row slice. if cpuT, ok := logits.(*cpu.Tensor); ok { shape := cpuT.Shape() if len(shape) != 2 || shape[1] != vocab { return 0, fmt.Errorf("expected logits shape [*,%d], got %v", vocab, shape) } if row >= shape[0] { return 0, fmt.Errorf("row %d out of range", row) } start := row * vocab end := start + vocab data := cpuT.DataFloat32() if end > len(data) { return 0, fmt.Errorf("cpu logits out of range") } return sampler.Sample(data[start:end], recent), nil } // CUDA logits: prefer GPU top-k path when enabled/greedy and supported by kernel. if cudaT, ok := logits.(*cuda.Tensor); ok { shape := cudaT.Shape() if len(shape) != 2 || shape[1] != vocab { return 0, fmt.Errorf("expected logits shape [*,%d], got %v", vocab, shape) } if row >= shape[0] { return 0, fmt.Errorf("row %d out of range", row) } view, err := cudaT.ViewAt(tensor.Shape{vocab}, uintptr(row*vocab*4)) if err != nil { return 0, err } // CUDA TopK kernel supports k<=64; fall back when disabled or too large. if k <= 0 || k > 64 { host := make([]float32, vocab) if err := view.CopyToHost(host); err != nil { return 0, err } return sampler.Sample(host, recent), nil } repPenalty := cfg.RepetitionPenalty if repPenalty <= 0 { repPenalty = 1.0 } repIDs := make([]int32, len(recent)) for i, t := range recent { repIDs[i] = int32(t) } allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocab, repIDs, repPenalty, k, cudaT.GPU()) if err != nil { return 0, err } cands := make([]struct { id int32 score float32 }, 0, blocks*k) for i := 0; i < blocks*k; i++ { if allIDs[i] < 0 { continue } cands = append(cands, struct { id int32 score float32 }{id: allIDs[i], score: allScores[i]}) } if len(cands) == 0 { host := make([]float32, vocab) if err := view.CopyToHost(host); err != nil { return 0, err } return sampler.Sample(host, recent), nil } sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score }) if len(cands) > k { cands = cands[:k] } finalIDs := make([]int32, len(cands)) finalScores := make([]float32, len(cands)) for i := range cands { finalIDs[i] = cands[i].id finalScores[i] = cands[i].score } return sampler.SampleFromTopK(finalIDs, finalScores), nil } return 0, fmt.Errorf("unsupported logits type: %T", logits) }