weight_cache.go
//go:build cuda

// Package compute provides GPU weight caching for persistent weight storage.
package compute

import (
	"fmt"
	"log"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/tensor"
)
// GPUWeightCache stores quantized weights on GPU for reuse across calls.
// This is the key optimization - upload once, use many times.
type GPUWeightCache struct {
	mu           sync.RWMutex
	weights      map[string]cachedWeight
	gpu          int
	totalMem     uint64 // bytes allocated
	allocByLayer map[int]uint64
	dupByTensor  map[uintptr]string
	allocCount   uint64
}

type cachedWeight struct {
	ptr       unsafe.Pointer
	tensor    *cuda.Tensor // Keep reference for float weights to prevent GC
	dtype     tensor.DType
	shape     tensor.Shape
	numBlocks int
	sizeBytes int
}

// Global weight cache per GPU
var (
	weightCaches  = make(map[int]*GPUWeightCache)
	weightCacheMu sync.Mutex
)
// GetWeightCache returns the weight cache for a GPU, creating if needed.
func GetWeightCache(gpu int) *GPUWeightCache {
	weightCacheMu.Lock()
	defer weightCacheMu.Unlock()
	if cache, ok := weightCaches[gpu]; ok {
		return cache
	}
	cache := &GPUWeightCache{
		weights:      make(map[string]cachedWeight),
		gpu:          gpu,
		allocByLayer: make(map[int]uint64),
		dupByTensor:  make(map[uintptr]string),
	}
	weightCaches[gpu] = cache
	return cache
}
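
// The function below is an illustrative sketch (not part of the original API)
// of the intended upload-once/reuse pattern; the "layer3.attn_q" key and the
// GPU index 0 are hypothetical examples.
func exampleWeightCacheUsage(t *cpu.Tensor) (unsafe.Pointer, error) {
	cache := GetWeightCache(0)
	// Fast path: the weight is already resident on the device.
	if ptr, ok := cache.Get("layer3.attn_q"); ok {
		return ptr, nil
	}
	// Slow path: upload once; subsequent calls hit the cache.
	return cache.Upload("layer3.attn_q", t)
}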
// Get returns a cached GPU weight pointer, or nil if not cached.
func (c *GPUWeightCache) Get(key string) (unsafe.Pointer, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if w, ok := c.weights[key]; ok {
		return w.ptr, true
	}
	return nil, false
}

// GetTensor returns the cached CUDA tensor for float weights (Float32/Float16 on GPU).
// For quantized weights, this returns (nil, false) since they are stored as raw device pointers.
func (c *GPUWeightCache) GetTensor(key string) (*cuda.Tensor, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	w, ok := c.weights[key]
	if !ok {
		return nil, false
	}
	if w.tensor == nil {
		return nil, false
	}
	return w.tensor, true
}
// Upload uploads a CPU tensor to GPU and caches it.
// Returns the GPU pointer for immediate use.
func (c *GPUWeightCache) Upload(key string, t *cpu.Tensor) (unsafe.Pointer, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Check if already cached
	if w, ok := c.weights[key]; ok {
		return w.ptr, nil
	}
	shape := t.Shape()
	dtype := t.DType()
	numElements := shape.NumElements()
	var ptr unsafe.Pointer
	var sizeBytes int
	var numBlocks int
	var err error
	switch dtype {
	case tensor.Float16, tensor.BFloat16:
		sizeBytes = numElements * 2
		gpuTensor, err2 := cuda.NewTensor(shape, dtype, c.gpu)
		if err2 != nil {
			return nil, fmt.Errorf("alloc %v weight: %w", dtype, err2)
		}
		if numElements > 0 {
			srcPtr := t.Data().(unsafe.Pointer)
			dstPtr := gpuTensor.Data().(unsafe.Pointer)
			if err2 := cuda.MemcpyH2D(dstPtr, srcPtr, uintptr(sizeBytes), c.gpu); err2 != nil {
				gpuTensor.Free()
				return nil, fmt.Errorf("copy %v weight: %w", dtype, err2)
			}
		}
		ptr = gpuTensor.Data().(unsafe.Pointer)
		// Store tensor reference to prevent GC.
		c.weights[key] = cachedWeight{
			ptr:       ptr,
			tensor:    gpuTensor,
			dtype:     dtype,
			shape:     shape,
			numBlocks: 0,
			sizeBytes: sizeBytes,
		}
		c.totalMem += uint64(sizeBytes)
		c.recordAlloc(key, sizeBytes, dtype, shape, t)
		return ptr, nil
	case tensor.Q8_K:
		numBlocks = numElements / 256
		sizeBytes = numBlocks * 292 // 4 (d) + 256 (qs) + 32 (bsums)
		data := t.Data().(unsafe.Pointer)
		dataSlice := unsafe.Slice((*byte)(data), sizeBytes)
		ptr, err = cuda.UploadQ8K(dataSlice, numBlocks, c.gpu)
	case tensor.Q5_K:
		numBlocks = numElements / 256
		sizeBytes = numBlocks * 176 // 2 (d) + 2 (dmin) + 12 (scales) + 32 (qh) + 128 (qs)
		data := t.Data().(unsafe.Pointer)
		dataSlice := unsafe.Slice((*byte)(data), sizeBytes)
		ptr, err = cuda.UploadQ5K(dataSlice, numBlocks, c.gpu)
	case tensor.Q4_K:
		numBlocks = numElements / 256
		sizeBytes = numBlocks * 144 // 2 (d) + 2 (dmin) + 12 (scales) + 128 (qs)
		data := t.Data().(unsafe.Pointer)
		dataSlice := unsafe.Slice((*byte)(data), sizeBytes)
		ptr, err = cuda.UploadQ4K(dataSlice, numBlocks, c.gpu)
	case tensor.Q2_K:
		numBlocks = numElements / 256
		sizeBytes = numBlocks * 84 // 16 (scales) + 64 (qs) + 2 (d) + 2 (dmin)
		data := t.Data().(unsafe.Pointer)
		dataSlice := unsafe.Slice((*byte)(data), sizeBytes)
		ptr, err = cuda.UploadQ2K(dataSlice, numBlocks, c.gpu)
	case tensor.Q3_K:
		numBlocks = numElements / 256
		sizeBytes = numBlocks * 110 // 32 (hmask) + 64 (qs) + 12 (scales) + 2 (d)
		data := t.Data().(unsafe.Pointer)
		dataSlice := unsafe.Slice((*byte)(data), sizeBytes)
		ptr, err = cuda.UploadQ3K(dataSlice, numBlocks, c.gpu)
	case tensor.Q6_K:
		numBlocks = numElements / 256
		sizeBytes = numBlocks * 210 // 128 (ql) + 64 (qh) + 16 (scales) + 2 (d)
		data := t.Data().(unsafe.Pointer)
		dataSlice := unsafe.Slice((*byte)(data), sizeBytes)
		ptr, err = cuda.UploadQ6K(dataSlice, numBlocks, c.gpu)
	case tensor.Float32:
		sizeBytes = numElements * 4
		gpuTensor, err2 := cuda.NewTensor(shape, tensor.Float32, c.gpu)
		if err2 != nil {
			return nil, fmt.Errorf("alloc F32 weight: %w", err2)
		}
		if err2 := gpuTensor.CopyFrom(t.DataFloat32()); err2 != nil {
			gpuTensor.Free()
			return nil, fmt.Errorf("copy F32 weight: %w", err2)
		}
		ptr = gpuTensor.Data().(unsafe.Pointer)
		// Store tensor reference to prevent GC.
		c.weights[key] = cachedWeight{
			ptr:       ptr,
			tensor:    gpuTensor,
			dtype:     dtype,
			shape:     shape,
			numBlocks: 0,
			sizeBytes: sizeBytes,
		}
		c.totalMem += uint64(sizeBytes)
		c.recordAlloc(key, sizeBytes, dtype, shape, t)
		return ptr, nil
	default:
		return nil, fmt.Errorf("unsupported dtype for GPU cache: %v", dtype)
	}
	if err != nil {
		return nil, err
	}
	c.weights[key] = cachedWeight{
		ptr:       ptr,
		tensor:    nil, // Quant weights store raw device pointers.
		dtype:     dtype,
		shape:     shape,
		numBlocks: numBlocks,
		sizeBytes: sizeBytes,
	}
	c.totalMem += uint64(sizeBytes)
	c.recordAlloc(key, sizeBytes, dtype, shape, t)
	return ptr, nil
}
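
// The function below is an illustrative sketch of the block-size arithmetic
// used in Upload: K-quant formats pack 256 elements per block, so a
// hypothetical 4096x4096 Q4_K weight needs (4096*4096/256) blocks of 144 bytes
// each on the device.
func exampleQ4KSizing() (numBlocks, sizeBytes int) {
	const numElements = 4096 * 4096
	numBlocks = numElements / 256 // 65536 blocks of 256 elements
	sizeBytes = numBlocks * 144   // 9437184 bytes = 9 MiB
	return numBlocks, sizeBytes
}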
// UploadF16 uploads a Float32 CPU tensor to GPU as Float16 and caches it.
// Intended for Tensor Core GEMM paths (e.g., dense matmul weights).
func (c *GPUWeightCache) UploadF16(key string, t *cpu.Tensor) (unsafe.Pointer, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Check if already cached
	if w, ok := c.weights[key]; ok {
		return w.ptr, nil
	}
	if t.DType() != tensor.Float32 {
		return nil, fmt.Errorf("UploadF16: expected Float32 tensor, got %v", t.DType())
	}
	shape := t.Shape()
	numElements := shape.NumElements()
	sizeBytes := numElements * 2
	tmpF32, err := cuda.NewTensor(shape, tensor.Float32, c.gpu)
	if err != nil {
		return nil, fmt.Errorf("alloc temp F32 weight: %w", err)
	}
	if err := tmpF32.CopyFrom(t.DataFloat32()); err != nil {
		tmpF32.Free()
		return nil, fmt.Errorf("copy temp F32 weight: %w", err)
	}
	gpuTensor, err := cuda.NewTensor(shape, tensor.Float16, c.gpu)
	if err != nil {
		tmpF32.Free()
		return nil, fmt.Errorf("alloc F16 weight: %w", err)
	}
	if err := cuda.CastF32ToF16(tmpF32.Data().(unsafe.Pointer), gpuTensor.Data().(unsafe.Pointer), numElements, c.gpu); err != nil {
		tmpF32.Free()
		gpuTensor.Free()
		return nil, fmt.Errorf("cast weight F32->F16: %w", err)
	}
	tmpF32.Free()
	ptr := gpuTensor.Data().(unsafe.Pointer)
	c.weights[key] = cachedWeight{
		ptr:       ptr,
		tensor:    gpuTensor,
		dtype:     tensor.Float16,
		shape:     shape,
		numBlocks: 0,
		sizeBytes: sizeBytes,
	}
	c.totalMem += uint64(sizeBytes)
	c.recordAlloc(key, sizeBytes, tensor.Float16, shape, t)
	return ptr, nil
}
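
// The function below is an illustrative sketch (not part of the original API):
// a dense Float32 matmul weight can be cached as Float16 for Tensor Core GEMM,
// while other dtypes fall back to the generic Upload path. The useTensorCores
// flag is a hypothetical caller-side decision.
func exampleUploadForGEMM(cache *GPUWeightCache, key string, w *cpu.Tensor, useTensorCores bool) (unsafe.Pointer, error) {
	if useTensorCores && w.DType() == tensor.Float32 {
		// Stored as F16, halving the resident footprint versus F32.
		return cache.UploadF16(key, w)
	}
	return cache.Upload(key, w)
}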
// TotalMemory returns total GPU memory used by cache in bytes.
func (c *GPUWeightCache) TotalMemory() uint64 {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.totalMem
}

// Clear frees all cached weights.
func (c *GPUWeightCache) Clear() {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, w := range c.weights {
		if w.tensor != nil {
			// Float weights: dropping the map reference below lets the
			// tensor's finalizer release the device memory.
			continue
		}
		// Quantized weights: free the raw device pointer explicitly.
		cuda.FreeDevicePtr(w.ptr)
	}
	c.weights = make(map[string]cachedWeight)
	c.totalMem = 0
	c.allocByLayer = make(map[int]uint64)
	c.dupByTensor = make(map[uintptr]string)
	c.allocCount = 0
}

// ClearAllCaches frees all GPU weight caches.
func ClearAllCaches() {
	weightCacheMu.Lock()
	defer weightCacheMu.Unlock()
	for _, cache := range weightCaches {
		cache.Clear()
	}
	weightCaches = make(map[int]*GPUWeightCache)
}
// LogWeightCacheSummary prints per-GPU and per-layer allocation summaries when enabled.
func LogWeightCacheSummary() {
	if !weightMemLogSummaryEnabled() {
		return
	}
	weightCacheMu.Lock()
	caches := make([]*GPUWeightCache, 0, len(weightCaches))
	for _, cache := range weightCaches {
		caches = append(caches, cache)
	}
	weightCacheMu.Unlock()
	for _, cache := range caches {
		cache.dumpSummary()
	}
}
// recordAlloc attributes an allocation to a per-layer bucket and optionally
// logs it. Callers must hold c.mu.
func (c *GPUWeightCache) recordAlloc(key string, sizeBytes int, dtype tensor.DType, shape tensor.Shape, t *cpu.Tensor) {
	if sizeBytes <= 0 {
		return
	}
	layer, ok := layerFromCacheKey(key)
	if !ok {
		layer = -1
	}
	if c.allocByLayer == nil {
		c.allocByLayer = make(map[int]uint64)
	}
	c.allocByLayer[layer] += uint64(sizeBytes)
	c.allocCount++
	if weightMemLogAllocEnabled() {
		log.Printf("gpu-cache alloc gpu=%d layer=%d bytes=%d total=%s key=%s dtype=%s shape=%v",
			c.gpu, layer, sizeBytes, formatBytes(c.totalMem), key, dtype.String(), shape)
		if t != nil {
			if c.dupByTensor == nil {
				c.dupByTensor = make(map[uintptr]string)
			}
			tID := uintptr(unsafe.Pointer(t))
			if prev, ok := c.dupByTensor[tID]; ok && prev != key {
				log.Printf("gpu-cache dup gpu=%d tensor=%p prev_key=%s new_key=%s", c.gpu, t, prev, key)
			} else if !ok {
				c.dupByTensor[tID] = key
			}
		}
	}
}
// dumpSummary logs cache totals and per-layer byte counts for this GPU.
func (c *GPUWeightCache) dumpSummary() {
	if c == nil {
		return
	}
	c.mu.RLock()
	total := c.totalMem
	allocCount := c.allocCount
	byLayer := make([]layerAlloc, 0, len(c.allocByLayer))
	for layer, bytes := range c.allocByLayer {
		byLayer = append(byLayer, layerAlloc{layer: layer, bytes: bytes})
	}
	c.mu.RUnlock()
	sort.Slice(byLayer, func(i, j int) bool {
		return byLayer[i].layer < byLayer[j].layer
	})
	totalMem, freeMem, err := cuda.MemoryInfoDevice(c.gpu)
	if err != nil {
		log.Printf("gpu-cache summary gpu=%d total=%s allocs=%d", c.gpu, formatBytes(total), allocCount)
	} else {
		log.Printf("gpu-cache summary gpu=%d total=%s allocs=%d free=%s/%s", c.gpu, formatBytes(total), allocCount, formatBytes(freeMem), formatBytes(totalMem))
	}
	for _, entry := range byLayer {
		label := "shared"
		if entry.layer >= 0 {
			label = fmt.Sprintf("layer%d", entry.layer)
		}
		log.Printf("gpu-cache layer=%s bytes=%s", label, formatBytes(entry.bytes))
	}
}

type layerAlloc struct {
	layer int
	bytes uint64
}
var (
	weightMemLogOnce    sync.Once
	weightMemLogAlloc   bool
	weightMemLogSummary bool
)

func weightMemLogAllocEnabled() bool {
	weightMemLogInit()
	return weightMemLogAlloc
}

func weightMemLogSummaryEnabled() bool {
	weightMemLogInit()
	return weightMemLogSummary
}

func weightMemLogInit() {
	weightMemLogOnce.Do(func() {
		raw := strings.ToLower(strings.TrimSpace(os.Getenv("MAKARNA_GPU_MEMLOG")))
		if raw == "" || raw == "0" || raw == "false" || raw == "off" {
			return
		}
		switch raw {
		case "1", "true", "all":
			weightMemLogAlloc = true
			weightMemLogSummary = true
			return
		}
		if strings.Contains(raw, "alloc") {
			weightMemLogAlloc = true
		}
		if strings.Contains(raw, "summary") {
			weightMemLogSummary = true
		}
		if !weightMemLogAlloc && !weightMemLogSummary {
			weightMemLogAlloc = true
		}
	})
}
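
// The function below is an illustrative sketch of how the MAKARNA_GPU_MEMLOG
// values parsed above map to behavior. The variable must be set before the
// first lookup, since weightMemLogInit runs only once.
func exampleEnableMemLog() {
	// "1", "true", or "all" enable both per-allocation logs and summaries;
	// "alloc" and "summary" enable the individual modes and may be combined
	// (e.g. "alloc,summary"); "", "0", "false", and "off" disable logging.
	os.Setenv("MAKARNA_GPU_MEMLOG", "alloc,summary")
}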
// layerFromCacheKey extracts the layer index from keys shaped like
// "layer<N>..." or "kda_l<N>...".
func layerFromCacheKey(key string) (int, bool) {
	if strings.HasPrefix(key, "layer") {
		rest := key[len("layer"):]
		n := readLeadingInt(rest)
		if n >= 0 {
			return n, true
		}
	}
	if strings.HasPrefix(key, "kda_l") {
		rest := key[len("kda_l"):]
		n := readLeadingInt(rest)
		if n >= 0 {
			return n, true
		}
	}
	return 0, false
}
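
// The function below is an illustrative sketch; the key strings are
// hypothetical examples of the naming convention assumed by layerFromCacheKey.
func exampleLayerKeys() {
	if n, ok := layerFromCacheKey("layer12.ffn_gate"); ok {
		_ = n // 12
	}
	if n, ok := layerFromCacheKey("kda_l3.conv"); ok {
		_ = n // 3
	}
	// Keys without a recognized prefix are filed under layer -1 ("shared")
	// by recordAlloc.
	_, _ = layerFromCacheKey("output_norm")
}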
// readLeadingInt parses the decimal digits at the start of s, returning -1 if
// there are none.
func readLeadingInt(s string) int {
	if s == "" {
		return -1
	}
	end := 0
	for end < len(s) && s[end] >= '0' && s[end] <= '9' {
		end++
	}
	if end == 0 {
		return -1
	}
	n, err := strconv.Atoi(s[:end])
	if err != nil {
		return -1
	}
	return n
}
// formatBytes renders a byte count using binary (1024-based) units.
func formatBytes(v uint64) string {
	const unit = 1024
	if v < unit {
		return fmt.Sprintf("%dB", v)
	}
	div, exp := uint64(unit), 0
	for n := v / unit; n >= unit && exp < 4; n /= unit {
		div *= unit
		exp++
	}
	value := float64(v) / float64(div)
	suffix := []string{"KB", "MB", "GB", "TB", "PB"}[exp]
	return fmt.Sprintf("%.2f%s", value, suffix)
}
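
// Illustrative values computed from the function above:
// formatBytes(512) == "512B", formatBytes(9437184) == "9.00MB",
// formatBytes(1610612736) == "1.50GB" (binary units despite the KB/MB suffixes).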