package engine

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/graph"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)

// Scheduler orchestrates prefill and decode while reusing a static graph plan.
// It supports multi-token steps (len(tokens) > 1) for continuous batching.
type Scheduler struct {
	engine *Engine
	plan   graph.ExecutionPlan
	cache  model.KVCache
}

// NewScheduler binds a model engine with a graph plan and KV cache.
func (e *Engine) NewScheduler(plan graph.ExecutionPlan, cache model.KVCache) *Scheduler {
	return &Scheduler{
		engine: e,
		plan:   plan,
		cache:  cache,
	}
}

// RemainingContext returns how many tokens can still fit in the reserved plan.
func (s *Scheduler) RemainingContext() int {
	return s.plan.MaxContext - s.cache.SeqLen()
}

// Prefill runs a single-step prefill (full prompt) and advances the cache.
func (s *Scheduler) Prefill(tokens []int) (tensor.Tensor, error) {
	return s.run(tokens)
}

// Decode runs multi-token decode in one step. The caller is responsible for
// picking the next tokens from the returned logits.
func (s *Scheduler) Decode(tokens []int) (tensor.Tensor, error) {
	return s.run(tokens)
}

func (s *Scheduler) run(tokens []int) (tensor.Tensor, error) {
	if len(tokens) == 0 {
		return nil, fmt.Errorf("no tokens to run")
	}
	if s.cache.SeqLen()+len(tokens) > s.plan.MaxContext {
		return nil, fmt.Errorf("context limit exceeded: need %d, max %d",
			s.cache.SeqLen()+len(tokens), s.plan.MaxContext)
	}

	input := cpu.NewTensor(tensor.Shape{len(tokens)}, nil)
	for i, id := range tokens {
		input.DataFloat32()[i] = float32(id)
	}

	pos := cpu.NewTensor(tensor.Shape{len(tokens)}, nil)
	base := s.cache.SeqLen()
	for i := range tokens {
		pos.DataFloat32()[i] = float32(base + i)
	}

	before := s.cache.SeqLen()
	logits, err := s.engine.Forward(nil, input, pos, s.cache)
	if err != nil {
		return nil, err
	}

	// Some model implementations advance the cache internally.
	// If the cache didn't advance, commit here.
	if s.cache.SeqLen() == before {
		s.cache.Commit(len(tokens))
	}
	return logits, nil
}
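
// Usage sketch (assumptions, not part of this package's API): the engine, plan,
// cache, eosID, and argmax helper below are hypothetical placeholders; only the
// Scheduler methods (Prefill, Decode, RemainingContext) come from this file.
// A typical generation loop prefills the full prompt once, then decodes one
// token per step, picking the next token from the returned logits:
//
//	sched := eng.NewScheduler(plan, cache)      // eng, plan, cache: set up elsewhere
//	logits, err := sched.Prefill(promptTokens)  // single step over the whole prompt
//	if err != nil {
//		return err
//	}
//	next := argmax(logits) // argmax: hypothetical sampling helper
//	for sched.RemainingContext() > 0 && next != eosID {
//		logits, err = sched.Decode([]int{next}) // single-token decode step
//		if err != nil {
//			return err
//		}
//		next = argmax(logits)
//	}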