device.go

// Package device provides cross-device tensor operations and placement management.
// It serves as the central hub for device-aware computation in the makarna engine.
package device

import (
	"fmt"
	"sync"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/tensor"
)

// WeightCache caches GPU copies of weights to avoid repeated H2D transfers.
// Thread-safe for concurrent layer execution.
type WeightCache struct {
	mu    sync.RWMutex
	cache map[string]*cuda.Tensor // key: "layer_idx:weight_name"
	gpuID int
}

// NewWeightCache creates a new weight cache for a specific GPU.
func NewWeightCache(gpuID int) *WeightCache {
	return &WeightCache{
		cache: make(map[string]*cuda.Tensor),
		gpuID: gpuID,
	}
}

// Get retrieves a cached GPU tensor, returning nil if not cached.
func (wc *WeightCache) Get(key string) *cuda.Tensor {
	wc.mu.RLock()
	defer wc.mu.RUnlock()
	return wc.cache[key]
}

// Put adds a GPU tensor to the cache.
func (wc *WeightCache) Put(key string, t *cuda.Tensor) {
	wc.mu.Lock()
	defer wc.mu.Unlock()
	wc.cache[key] = t
}

// Clear frees all cached GPU tensors.
func (wc *WeightCache) Clear() {
	wc.mu.Lock()
	defer wc.mu.Unlock()
	// Release each GPU copy before dropping the map, so Clear actually frees
	// device memory as documented rather than only discarding references.
	for _, t := range wc.cache {
		t.Free()
	}
	wc.cache = make(map[string]*cuda.Tensor)
}

// EnsureOn returns a tensor on the requested placement, copying if needed.
// For CPU tensors going to CUDA, this creates a NEW tensor each time.
// Use EnsureOnCached for weight tensors that should be cached.
func EnsureOn(t tensor.Tensor, target tensor.DevicePlacement) (tensor.Tensor, error) {
	if twp, ok := t.(tensor.TensorWithPlacement); ok {
		if twp.Placement() == target.Normalize() {
			return t, nil
		}
	}
	switch target.Type {
	case tensor.CPU:
		return toCPU(t)
	case tensor.CUDA:
		return toCUDA(t, target.GPU)
	default:
		return nil, fmt.Errorf("unsupported target device %v", target.Type)
	}
}

// EnsureOnCached is like EnsureOn but uses a cache for weight tensors.
// The key should uniquely identify the weight (e.g., "layer_0:wq").
func EnsureOnCached(t tensor.Tensor, target tensor.DevicePlacement, cache *WeightCache, key string) (tensor.Tensor, error) {
	if target.Type != tensor.CUDA || cache == nil {
		return EnsureOn(t, target)
	}
	// Check the cache first.
	if cached := cache.Get(key); cached != nil {
		return cached, nil
	}
	// Not cached: create the GPU copy and cache it.
	result, err := toCUDA(t, target.GPU)
	if err != nil {
		return nil, err
	}
	if cudaTensor, ok := result.(*cuda.Tensor); ok {
		cache.Put(key, cudaTensor)
	}
	return result, nil
}
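
// Illustrative caller-side sketch (hypothetical variables such as dispatcher,
// layer and i; not part of this file's API): in a per-layer forward pass,
// activations go through EnsureOn (fresh copy each time) while weights go
// through EnsureOnCached so the H2D transfer happens only once per weight.
//
//	placement := dispatcher.LayerPlacement(i)
//	cache := dispatcher.GetWeightCache(placement.GPU)
//	x, err := device.EnsureOn(x, placement) // activation: copied as needed
//	if err != nil {
//		return err
//	}
//	wq, err := device.EnsureOnCached(layer.WQ, placement, cache, fmt.Sprintf("layer_%d:wq", i))
//	if err != nil {
//		return err
//	}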

// CUDAAvailable returns whether CUDA is available.
func CUDAAvailable() bool {
	return cuda.Available()
}
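
// Illustrative caller-side sketch (hypothetical target variable): fall back to
// a CPU placement when the binary was not built with -tags=cuda or no usable
// GPU is present.
//
//	target := tensor.DevicePlacement{Type: tensor.CUDA, GPU: 0}
//	if !device.CUDAAvailable() {
//		target = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
//	}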

// toCPU copies a tensor into host memory, returning CPU tensors unchanged.
func toCPU(t tensor.Tensor) (tensor.Tensor, error) {
	if c, ok := t.(*cpu.Tensor); ok {
		return c, nil
	}
	switch src := t.(type) {
	case *cuda.Tensor:
		if !cuda.Available() {
			return nil, fmt.Errorf("CUDA not available")
		}
		out := cpu.NewTensor(src.Shape(), nil)
		host := out.DataFloat32()
		if err := src.CopyToHost(host); err != nil {
			return nil, fmt.Errorf("copy to host failed: %w", err)
		}
		return out, nil
	default:
		return nil, fmt.Errorf("toCPU: unsupported tensor type %T", t)
	}
}

// toCUDA copies a tensor to the given GPU, returning it unchanged if it is
// already resident there.
func toCUDA(t tensor.Tensor, gpu int) (tensor.Tensor, error) {
	if !cuda.Available() {
		return nil, fmt.Errorf("CUDA not available - build with -tags=cuda")
	}
	if src, ok := t.(*cuda.Tensor); ok {
		// Already on a GPU: return as-is or copy across devices.
		if src.GPU() == gpu {
			return src, nil
		}
		if src.DType() != tensor.Float32 {
			return nil, fmt.Errorf("cross-GPU tensor copy only supports float32, got %v", src.DType())
		}
		out, err := cuda.NewTensor(src.Shape(), src.DType(), gpu)
		if err != nil {
			return nil, err
		}
		size := uintptr(src.Shape().NumElements() * src.DType().Size())
		if err := cuda.MemcpyD2D(out.Data().(unsafe.Pointer), src.Data().(unsafe.Pointer), size, gpu); err != nil {
			// Conservative fallback: stage the copy through host memory.
			host := make([]float32, src.Shape().NumElements())
			if err2 := src.CopyToHost(host); err2 != nil {
				out.Free()
				return nil, fmt.Errorf("cross-GPU copy D2H failed: %w", err2)
			}
			if err2 := out.CopyFrom(host); err2 != nil {
				out.Free()
				return nil, fmt.Errorf("cross-GPU copy H2D failed: %w", err2)
			}
		}
		return out, nil
	}
	// Quantized tensors would need dequantization before upload; only float32
	// host tensors are supported here.
	if t.DType() != tensor.Float32 {
		return nil, fmt.Errorf("toCUDA: only float32 currently supported, got %v", t.DType())
	}
	out, err := cuda.NewTensor(t.Shape(), t.DType(), gpu)
	if err != nil {
		return nil, err
	}
	switch s := t.(type) {
	case *cpu.Tensor:
		if err := out.CopyFrom(s.DataFloat32()); err != nil {
			out.Free()
			return nil, err
		}
	default:
		out.Free()
		return nil, fmt.Errorf("toCUDA: unsupported source type %T", t)
	}
	return out, nil
}
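
// Illustrative caller-side sketch (hypothetical variables): an existing CUDA
// tensor on GPU 0 can be moved to GPU 1 via EnsureOn; the direct
// device-to-device copy is attempted first, with host staging as the fallback.
//
//	t1, err := device.EnsureOn(t0, tensor.DevicePlacement{Type: tensor.CUDA, GPU: 1})
//	if err != nil {
//		return err
//	}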

// DeviceDispatcher manages per-device operations and caching.
type DeviceDispatcher struct {
	layerDevices []tensor.DevicePlacement
	weightCaches map[int]*WeightCache // gpuID -> cache
	mu           sync.RWMutex
}

// NewDeviceDispatcher creates a dispatcher with the given layer placements.
func NewDeviceDispatcher(layerDevices []tensor.DevicePlacement) *DeviceDispatcher {
	dd := &DeviceDispatcher{
		layerDevices: layerDevices,
		weightCaches: make(map[int]*WeightCache),
	}
	// Pre-create caches for each GPU mentioned in the placements.
	for _, p := range layerDevices {
		if p.Type == tensor.CUDA {
			if _, exists := dd.weightCaches[p.GPU]; !exists {
				dd.weightCaches[p.GPU] = NewWeightCache(p.GPU)
			}
		}
	}
	return dd
}

// LayerPlacement returns the device placement for a layer.
func (dd *DeviceDispatcher) LayerPlacement(layerIdx int) tensor.DevicePlacement {
	if layerIdx >= 0 && layerIdx < len(dd.layerDevices) {
		return dd.layerDevices[layerIdx]
	}
	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
}

// GetWeightCache returns the weight cache for a GPU, creating one if needed.
func (dd *DeviceDispatcher) GetWeightCache(gpuID int) *WeightCache {
	dd.mu.Lock()
	defer dd.mu.Unlock()
	if cache, exists := dd.weightCaches[gpuID]; exists {
		return cache
	}
	cache := NewWeightCache(gpuID)
	dd.weightCaches[gpuID] = cache
	return cache
}

// IsLayerOnGPU returns true if the layer should run on GPU.
func (dd *DeviceDispatcher) IsLayerOnGPU(layerIdx int) bool {
	p := dd.LayerPlacement(layerIdx)
	return p.Type == tensor.CUDA
}

// NumGPULayers counts how many layers are placed on GPU.
func (dd *DeviceDispatcher) NumGPULayers() int {
	count := 0
	for _, p := range dd.layerDevices {
		if p.Type == tensor.CUDA {
			count++
		}
	}
	return count
}

// Clear frees all cached resources.
func (dd *DeviceDispatcher) Clear() {
	dd.mu.Lock()
	defer dd.mu.Unlock()
	for _, cache := range dd.weightCaches {
		cache.Clear()
	}
}
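
// Illustrative caller-side sketch (hypothetical numLayers and gpuLayers values):
// offload the first gpuLayers layers to GPU 0 when CUDA is available, keep the
// rest on CPU, and drive per-layer execution through the dispatcher.
//
//	placements := make([]tensor.DevicePlacement, numLayers)
//	for i := range placements {
//		if i < gpuLayers && device.CUDAAvailable() {
//			placements[i] = tensor.DevicePlacement{Type: tensor.CUDA, GPU: 0}
//		} else {
//			placements[i] = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
//		}
//	}
//	dispatcher := device.NewDeviceDispatcher(placements)
//	defer dispatcher.Clear() // release cached GPU weights when done
//	for i := 0; i < numLayers; i++ {
//		if dispatcher.IsLayerOnGPU(i) {
//			// GPU path for layer i: move weights with EnsureOnCached
//		} else {
//			// CPU path for layer i
//		}
//	}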