graph.go

package graph

import "makarna/pkg/tensor"

// ExecutionPlan describes the static computation graph for a request.
// It is intentionally lightweight so the plan can be reused across
// decode steps without rebuilding the structure.
type ExecutionPlan struct {
	// RequestID links the plan to a running session.
	RequestID string
	// MaxContext tokens reserved for this request. The KV cache manager
	// must have already reserved enough blocks to satisfy this budget.
	MaxContext int
	// BlockSize controls how many tokens are packed in each KV block.
	BlockSize int
	// Layers lists per-layer stage information (prefill/decode flags).
	Layers []LayerPlan
}

// LayerPlan captures per-layer execution intent. The current engine only
// needs to distinguish whether a layer participates in decode.
type LayerPlan struct {
	Index          int
	HasAttention   bool
	HasMLP         bool
	SupportsDecode bool
	Device         tensor.DevicePlacement
}

// RequestSpec declares what a caller wants to run. The scheduler converts
// this into an ExecutionPlan and hands it to the runtime.
type RequestSpec struct {
	ID           string
	MaxContext   int
	BlockSize    int
	NumLayers    int
	UseAttention bool
	LayerDevices []tensor.DevicePlacement
}

// BuildPlan produces a minimal ExecutionPlan suitable for single-GPU decode.
// The plan stays constant while the scheduler feeds new token batches.
func BuildPlan(spec RequestSpec) ExecutionPlan {
	plan := ExecutionPlan{
		RequestID:  spec.ID,
		MaxContext: spec.MaxContext,
		BlockSize:  spec.BlockSize,
		Layers:     make([]LayerPlan, spec.NumLayers),
	}
	for i := 0; i < spec.NumLayers; i++ {
		device := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
		if i < len(spec.LayerDevices) {
			device = spec.LayerDevices[i].Normalize()
		}
		plan.Layers[i] = LayerPlan{
			Index:          i,
			HasAttention:   spec.UseAttention,
			HasMLP:         true,
			SupportsDecode: true,
			Device:         device,
		}
	}
	return plan
}
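
A minimal usage sketch (not part of graph.go): how a caller might turn a RequestSpec into a plan with BuildPlan. The sizing values, the main wrapper, and the import path makarna/pkg/graph (assumed by analogy with the makarna/pkg/tensor import above) are illustrative assumptions; only RequestSpec, BuildPlan, and the CPU fallback for missing LayerDevices come from the file itself.

	package main

	import (
		"fmt"

		"makarna/pkg/graph"
	)

	func main() {
		// Hypothetical request sizing; any caller-supplied values work.
		spec := graph.RequestSpec{
			ID:           "req-42",
			MaxContext:   4096,
			BlockSize:    16,
			NumLayers:    32,
			UseAttention: true,
			// LayerDevices is left empty, so BuildPlan falls back to CPU
			// placement for every layer.
		}

		plan := graph.BuildPlan(spec)
		fmt.Printf("plan %s: %d layers, %d-token blocks\n",
			plan.RequestID, len(plan.Layers), plan.BlockSize)
	}

Because BuildPlan returns a value with no internal pointers, the scheduler can keep reusing the same plan across decode steps, as the package comment describes.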