| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- package engine
- import (
- "fmt"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/graph"
- "makarna/pkg/model"
- "makarna/pkg/tensor"
- )
// Scheduler orchestrates prefill and decode while reusing a static graph plan.
// It supports multi-token steps (len(tokens) > 1) for continuous batching.
//
// A Scheduler is bound to exactly one Engine, one ExecutionPlan, and one
// KVCache; create it via Engine.NewScheduler.
type Scheduler struct {
	engine *Engine             // model engine that performs the forward pass
	plan   graph.ExecutionPlan // static plan; MaxContext bounds total tokens per sequence
	cache  model.KVCache       // KV cache tracking the committed sequence length
}
- // NewScheduler binds a model engine with a graph plan and KV cache.
- func (e *Engine) NewScheduler(plan graph.ExecutionPlan, cache model.KVCache) *Scheduler {
- return &Scheduler{
- engine: e,
- plan: plan,
- cache: cache,
- }
- }
- // RemainingContext returns how many tokens can still fit in the reserved plan.
- func (s *Scheduler) RemainingContext() int {
- return s.plan.MaxContext - s.cache.SeqLen()
- }
- // Prefill runs a single step prefill (full prompt) and advances the cache.
- func (s *Scheduler) Prefill(tokens []int) (tensor.Tensor, error) {
- return s.run(tokens)
- }
- // Decode runs multi-token decode in one step. The caller is responsible for
- // picking the next tokens from the returned logits.
- func (s *Scheduler) Decode(tokens []int) (tensor.Tensor, error) {
- return s.run(tokens)
- }
- func (s *Scheduler) run(tokens []int) (tensor.Tensor, error) {
- if len(tokens) == 0 {
- return nil, fmt.Errorf("no tokens to run")
- }
- if s.cache.SeqLen()+len(tokens) > s.plan.MaxContext {
- return nil, fmt.Errorf("context limit exceeded: need %d, max %d", s.cache.SeqLen()+len(tokens), s.plan.MaxContext)
- }
- input := cpu.NewTensor(tensor.Shape{len(tokens)}, nil)
- for i, id := range tokens {
- input.DataFloat32()[i] = float32(id)
- }
- pos := cpu.NewTensor(tensor.Shape{len(tokens)}, nil)
- base := s.cache.SeqLen()
- for i := range tokens {
- pos.DataFloat32()[i] = float32(base + i)
- }
- before := s.cache.SeqLen()
- logits, err := s.engine.Forward(nil, input, pos, s.cache)
- if err != nil {
- return nil, err
- }
- // Some model implementations advance the cache internally.
- // If the cache didn't advance, commit here.
- if s.cache.SeqLen() == before {
- s.cache.Commit(len(tokens))
- }
- return logits, nil
- }
|