// cuda_memory.cu

#include "cuda_common.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <unordered_map>
#include <vector>
// --- Memory ---
namespace {
struct free_block {
    void * ptr;
    size_t size;
};
struct device_pool {
    std::mutex mu;
    std::unordered_map<void *, size_t> alloc_sizes;
    std::vector<free_block> free_list;
    size_t cached_bytes = 0;
};
// Current CUDA device cached per host thread.
// This is updated by cuda_set_device and used by cuda_malloc/cuda_free.
static thread_local int tls_device = 0;
// Keep a small per-device cache of freed allocations to avoid cudaMalloc/cudaFree churn
// and to keep VRAM usage stable after first-touch allocations.
static device_pool g_pools[16];
static constexpr size_t MAX_FREE_BLOCKS_PER_DEVICE = 1024;
static size_t g_pool_max_cached_bytes = 512ULL << 20; // 512MB
static size_t g_pool_max_block_bytes = 64ULL << 20;   // 64MB
static bool g_pool_enabled = true;
static std::once_flag g_pool_config_once;
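// Parse a byte count from an environment variable value, accepting an optional
// k/K, m/M or g/G suffix, e.g. "512m" -> 512 MiB, "1g" -> 1 GiB, "4096" -> 4096 bytes.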
static size_t parse_env_bytes(const char * env, size_t def_val) {
    if (env == nullptr || env[0] == '\0') {
        return def_val;
    }
    char * end = nullptr;
    unsigned long long val = std::strtoull(env, &end, 10);
    if (end != nullptr && *end != '\0') {
        switch (*end) {
            case 'k':
            case 'K':
                val *= 1024ULL;
                break;
            case 'm':
            case 'M':
                val *= 1024ULL * 1024ULL;
                break;
            case 'g':
            case 'G':
                val *= 1024ULL * 1024ULL * 1024ULL;
                break;
            default:
                break;
        }
    }
    return static_cast<size_t>(val);
}
static bool env_true(const char * env) {
    if (env == nullptr) {
        return false;
    }
    if (std::strcmp(env, "1") == 0 || std::strcmp(env, "true") == 0 || std::strcmp(env, "TRUE") == 0) {
        return true;
    }
    return false;
}
static void init_pool_config() {
    std::call_once(g_pool_config_once, []() {
        const char * disable = std::getenv("MAKARNA_CUDA_POOL_DISABLE");
        if (env_true(disable)) {
            g_pool_enabled = false;
            g_pool_max_cached_bytes = 0;
            g_pool_max_block_bytes = 0;
            return;
        }
        const char * max_bytes = std::getenv("MAKARNA_CUDA_POOL_MAX_BYTES");
        const char * max_block = std::getenv("MAKARNA_CUDA_POOL_MAX_BLOCK_BYTES");
        g_pool_max_cached_bytes = parse_env_bytes(max_bytes, g_pool_max_cached_bytes);
        g_pool_max_block_bytes = parse_env_bytes(max_block, g_pool_max_block_bytes);
    });
}
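// Example (illustrative values): running with
//   MAKARNA_CUDA_POOL_MAX_BYTES=1g MAKARNA_CUDA_POOL_MAX_BLOCK_BYTES=128m
// caches up to 1 GiB of freed blocks per device and keeps only blocks of 128 MiB
// or less; MAKARNA_CUDA_POOL_DISABLE=1 bypasses the cache and frees immediately.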
static device_pool & pool_for(int device) {
    if (device < 0) device = 0;
    if (device >= 16) device = device % 16;
    return g_pools[device];
}
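// Try to satisfy an allocation from the device's free list (best fit, no splitting).
// The returned block may be larger than the requested size; alloc_sizes keeps the
// block's original size so a later cuda_free returns the full block to the cache.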
static void * pool_alloc(int device, size_t size) {
    init_pool_config();
    device_pool & p = pool_for(device);
    std::lock_guard<std::mutex> lock(p.mu);
    // Best-fit search: pick the smallest block that satisfies the request.
    size_t best_i = (size_t) -1;
    size_t best_size = (size_t) -1;
    for (size_t i = 0; i < p.free_list.size(); ++i) {
        const free_block & b = p.free_list[i];
        if (b.size >= size && b.size < best_size) {
            best_i = i;
            best_size = b.size;
        }
    }
    if (best_i != (size_t) -1) {
        void * ptr = p.free_list[best_i].ptr;
        size_t bsize = p.free_list[best_i].size;
        // erase by swap-with-back
        p.free_list[best_i] = p.free_list.back();
        p.free_list.pop_back();
        if (p.cached_bytes >= bsize) {
            p.cached_bytes -= bsize;
        } else {
            p.cached_bytes = 0;
        }
        return ptr;
    }
    return nullptr;
}
static void pool_record_alloc(int device, void * ptr, size_t size) {
    if (ptr == nullptr) return;
    device_pool & p = pool_for(device);
    std::lock_guard<std::mutex> lock(p.mu);
    p.alloc_sizes[ptr] = size;
}
static size_t pool_lookup_size(int device, void * ptr) {
    device_pool & p = pool_for(device);
    std::lock_guard<std::mutex> lock(p.mu);
    auto it = p.alloc_sizes.find(ptr);
    if (it == p.alloc_sizes.end()) {
        return 0;
    }
    return it->second;
}
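// A pointer may be freed from a thread whose cached device differs from the device
// it was allocated on, so scan every pool's bookkeeping to find the owning device.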
static int pool_find_device(void * ptr, size_t * out_size) {
    if (out_size) *out_size = 0;
    if (ptr == nullptr) return -1;
    for (int d = 0; d < 16; ++d) {
        device_pool & p = g_pools[d];
        std::lock_guard<std::mutex> lock(p.mu);
        auto it = p.alloc_sizes.find(ptr);
        if (it != p.alloc_sizes.end()) {
            if (out_size) *out_size = it->second;
            return d;
        }
    }
    return -1;
}
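// Return a block to the owning device's cache, or free it immediately with cudaFree
// when pooling is disabled, the block is too large, or the cache is already full.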
static void pool_free(int device, void * ptr) {
    init_pool_config();
    if (ptr == nullptr) return;
    size_t size = pool_lookup_size(device, ptr);
    int actual_device = device;
    if (size == 0) {
        int found = pool_find_device(ptr, &size);
        if (found >= 0) {
            actual_device = found;
        }
    }
    device_pool & p = pool_for(actual_device);
    std::lock_guard<std::mutex> lock(p.mu);
    const bool pool_off = !g_pool_enabled || g_pool_max_cached_bytes == 0 || g_pool_max_block_bytes == 0 ||
                          size > g_pool_max_block_bytes;
    const bool pool_full = p.free_list.size() >= MAX_FREE_BLOCKS_PER_DEVICE ||
                           p.cached_bytes + size > g_pool_max_cached_bytes;
    if (pool_off || pool_full) {
        // Pool disabled, block too large, or cache full: actually free.
        // Switch to the owning device for the free, then restore the caller's device
        // so the per-thread device cache (tls_device) stays accurate.
        if (actual_device != device) {
            cudaSetDevice(actual_device);
        }
        cudaFree(ptr);
        p.alloc_sizes.erase(ptr);
        if (actual_device != device) {
            cudaSetDevice(device);
        }
        return;
    }
    p.free_list.push_back(free_block{ptr, size});
    p.cached_bytes += size;
}
} // namespace
int cuda_set_device(int id) {
    // cudaSetDevice is expensive when called repeatedly.
    // Cache per host thread since CUDA device selection is thread-local.
    if (tls_device == id) {
        return 0;
    }
    CHECK_CUDA(cudaSetDevice(id));
    tls_device = id;
    return 0;
}
void* cuda_malloc(size_t size) {
    init_pool_config();
    const int device = tls_device;
    void * ptr = pool_alloc(device, size);
    if (ptr != nullptr) {
        return ptr;
    }
    ptr = nullptr;
    if (cudaMalloc(&ptr, size) != cudaSuccess) {
        return nullptr;
    }
    pool_record_alloc(device, ptr, size);
    return ptr;
}
void cuda_free(void* ptr) {
    const int device = tls_device;
    pool_free(device, ptr);
}
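// Usage sketch (illustrative; buffer names are caller-side):
//   cuda_set_device(0);
//   float * d_buf = static_cast<float *>(cuda_malloc(n * sizeof(float)));
//   cuda_memcpy_h2d(d_buf, h_buf, n * sizeof(float));
//   // ... launch kernels ...
//   cuda_memcpy_d2h(h_buf, d_buf, n * sizeof(float));
//   cuda_free(d_buf); // block is cached for reuse when pooling is enabled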
int cuda_synchronize() {
    CHECK_CUDA(cudaDeviceSynchronize());
    return 0;
}
int cuda_memcpy_h2d(void* dst, void* src, size_t size) {
    CHECK_CUDA(cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
    return 0;
}
int cuda_memcpy_d2h(void* dst, void* src, size_t size) {
    CHECK_CUDA(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
    return 0;
}
int cuda_memcpy_d2d(void* dst, void* src, size_t size) {
    CHECK_CUDA(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
    return 0;
}
int cuda_mem_info(size_t* free_bytes, size_t* total_bytes) {
    // cudaMemGetInfo can return cudaErrorOperatingSystem in some restricted
    // environments even though allocations/kernels work. Fall back to probing
    // allocations so higher-level placement logic can still function.
    cudaError_t err = cudaMemGetInfo(free_bytes, total_bytes);
    if (err == cudaSuccess) {
        return 0;
    }
    if (err == cudaErrorOperatingSystem) {
        // Some sandboxes block driver queries (MemGetInfo/GetDeviceProperties)
        // but still allow allocations. Approximate "free" with a probing alloc:
        // double the probe size until cudaMalloc fails, then bisect the gap.
        (void)cudaGetLastError();
        size_t max_ok = 0;
        size_t probe = 256ULL << 20;          // 256MB
        const size_t max_probe = 64ULL << 30; // 64GB cap
        void* p = nullptr;
        while (probe <= max_probe) {
            cudaError_t e = cudaMalloc(&p, probe);
            if (e == cudaSuccess) {
                (void)cudaFree(p);
                p = nullptr;
                max_ok = probe;
                probe <<= 1;
                continue;
            }
            (void)cudaGetLastError();
            break;
        }
        size_t lo = max_ok;
        size_t hi = probe;
        // Binary search to 64MB granularity.
        const size_t gran = 64ULL << 20;
        while (hi > lo + gran) {
            size_t mid = lo + (hi - lo) / 2;
            mid = (mid / (1ULL << 20)) * (1ULL << 20); // align to 1MB
            if (mid <= lo) {
                break;
            }
            cudaError_t e = cudaMalloc(&p, mid);
            if (e == cudaSuccess) {
                (void)cudaFree(p);
                p = nullptr;
                lo = mid;
            } else {
                (void)cudaGetLastError();
                hi = mid;
            }
        }
        // The true total is unknown here; report the probed free amount for both.
        if (free_bytes) {
            *free_bytes = lo;
        }
        if (total_bytes) {
            *total_bytes = lo;
        }
        return 0;
    }
    fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err));
    return 1;
}
int cuda_device_count(int* count) {
    int c = 0;
    CHECK_CUDA(cudaGetDeviceCount(&c));
    if (count) {
        *count = c;
    }
    return 0;
}