compute.go

// Package compute provides device-agnostic computation dispatching.
// Operations automatically route to the appropriate backend (CPU/CUDA)
// based on tensor placement, eliminating manual device management in model code.
package compute

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/tensor"
)
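
// Illustrative usage sketch (not part of the package API): the dispatcher
// setup, layer loop, and variable names below are assumptions made only to
// show how model code is expected to drive a Context.
//
//	ctx := compute.NewContext(dispatcher, layerIdx)
//	x, err := ctx.EnsureActivation(hidden) // move activations to the layer's device
//	if err != nil {
//		return err
//	}
//	// Subsequent ops run on whichever device LayerPlacement chose for this
//	// layer; model code never branches on CPU vs CUDA explicitly.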

// Context holds computation state for a forward pass.
type Context struct {
	Dispatcher *device.DeviceDispatcher
	LayerIdx   int
	Scratch    *ScratchSpace
	CPUMoE     bool // Keep MoE expert weights on CPU
}

// NewContext creates a computation context.
func NewContext(dispatcher *device.DeviceDispatcher, layerIdx int) *Context {
	return &Context{
		Dispatcher: dispatcher,
		LayerIdx:   layerIdx,
	}
}

// Placement returns the current layer's device placement.
func (c *Context) Placement() tensor.DevicePlacement {
	if c.Dispatcher == nil {
		return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	}
	return c.Dispatcher.LayerPlacement(c.LayerIdx)
}

// IsGPU reports whether the current layer is placed on a GPU.
func (c *Context) IsGPU() bool {
	return c.Placement().Type == tensor.CUDA
}

// EnsureWeight ensures a weight tensor is on the correct device, caching the
// device-resident copy so it is not re-uploaded on every forward pass.
func (c *Context) EnsureWeight(t tensor.Tensor, name string) (tensor.Tensor, error) {
	if c.Dispatcher == nil {
		return t, nil
	}
	placement := c.Placement()
	if placement.Type == tensor.CPU {
		return t, nil
	}
	cache := c.Dispatcher.GetWeightCache(placement.GPU)
	key := fmt.Sprintf("%d:%s", c.LayerIdx, name)
	return device.EnsureOnCached(t, placement, cache, key)
}
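
// Sketch of how a layer might call EnsureWeight; the layer field and the
// "attn.wq" weight name are hypothetical, chosen only to show the cache key
// (here "<layerIdx>:attn.wq") staying stable across forward passes.
//
//	wq, err := ctx.EnsureWeight(layer.AttnWq, "attn.wq")
//	if err != nil {
//		return err
//	}
//	// The first call uploads the weight to the layer's GPU; later calls hit the cache.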

// EnsureActivation ensures an activation tensor is on the correct device.
// Unlike weights, activations are not cached between forward passes.
func (c *Context) EnsureActivation(t tensor.Tensor) (tensor.Tensor, error) {
	if c.Dispatcher == nil {
		return t, nil
	}
	return device.EnsureOn(t, c.Placement())
}

// Zeros creates a zero tensor on the appropriate device.
func Zeros(ctx *Context, shape tensor.Shape) tensor.Tensor {
	if ctx == nil || !ctx.IsGPU() || !device.CUDAAvailable() {
		return cpu.NewTensor(shape, nil)
	}
	t, err := cuda.NewTensor(shape, tensor.Float32, ctx.Placement().GPU)
	if err != nil {
		// Fall back to CPU if the CUDA allocation fails.
		return cpu.NewTensor(shape, nil)
	}
	return t
}

// ZerosCPU always creates a CPU tensor (for inputs/outputs).
func ZerosCPU(shape tensor.Shape) *cpu.Tensor {
	return cpu.NewTensor(shape, nil)
}
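
// Illustrative allocation sketch; the variable names and the Shape literal
// syntax are assumptions. Zeros follows the current layer's placement (and
// falls back to CPU when CUDA is unavailable), while ZerosCPU pins host-side
// buffers such as final outputs.
//
//	scratch := compute.Zeros(ctx, tensor.Shape{batch, hidden}) // device follows the layer
//	logits := compute.ZerosCPU(tensor.Shape{batch, vocab})     // always host memory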

// ToCPU copies a tensor to CPU if needed.
func ToCPU(t tensor.Tensor) (*cpu.Tensor, error) {
	if cpuT, ok := t.(*cpu.Tensor); ok {
		return cpuT, nil
	}
	result, err := device.EnsureOn(t, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
	if err != nil {
		return nil, err
	}
	return result.(*cpu.Tensor), nil
}
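
// Sketch: pulling a device-resident result back to host memory before reading
// its values; the variable names are assumptions for the example.
//
//	out, err := compute.ToCPU(deviceLogits)
//	if err != nil {
//		return err
//	}
//	host := out.DataFloat32() // safe to index on the host now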

// Copy copies data between tensors, handling the same-device cases and
// rejecting copies that would require a cross-device transfer.
func Copy(dst, src tensor.Tensor) error {
	// CPU -> CPU: copy the underlying float32 slices directly.
	if dstCPU, ok := dst.(*cpu.Tensor); ok {
		if srcCPU, ok := src.(*cpu.Tensor); ok {
			copy(dstCPU.DataFloat32(), srcCPU.DataFloat32())
			return nil
		}
	}
	// CUDA -> CUDA: not supported yet.
	if dstCUDA, ok := dst.(*cuda.Tensor); ok {
		if srcCUDA, ok := src.(*cuda.Tensor); ok {
			// TODO: CUDA-to-CUDA copy kernel
			_ = dstCUDA
			_ = srcCUDA
			return fmt.Errorf("CUDA-to-CUDA copy not implemented")
		}
	}
	// Cross-device: the caller must convert explicitly (e.g. via ToCPU) first.
	return fmt.Errorf("cross-device copy requires explicit conversion")
}
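
// Sketch of routing a copy whose source may live on the GPU: bring the source
// to the CPU first, then perform a host-side Copy. The destination is assumed
// to be a CPU tensor (e.g. from ZerosCPU); variable names are assumptions.
//
//	srcHost, err := compute.ToCPU(src)
//	if err != nil {
//		return err
//	}
//	if err := compute.Copy(dstCPU, srcHost); err != nil {
//		return err
//	}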