common.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. /*
  2. * Copyright (c) 2023-2024 The ggml authors
  3. *
  4. * Permission is hereby granted, free of charge, to any person obtaining a copy
  5. * of this software and associated documentation files (the "Software"), to
  6. * deal in the Software without restriction, including without limitation the
  7. * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
  8. * sell copies of the Software, and to permit persons to whom the Software is
  9. * furnished to do so, subject to the following conditions:
  10. *
  11. * The above copyright notice and this permission notice shall be included in
  12. * all copies or substantial portions of the Software.
  13. *
  14. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  15. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  16. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  17. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  18. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  19. * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  20. * IN THE SOFTWARE.
  21. */
  22. #ifndef CANN_COMMON_H
  23. #define CANN_COMMON_H
  24. #include <acl/acl.h>
  25. #include <cstdio>
  26. #include <iostream>
  27. #include <map>
  28. #include <memory>
  29. #include <string>
  30. #include <vector>
  31. #include <atomic>
  32. #include <condition_variable>
  33. #include <mutex>
  34. #include <thread>
  35. #include <unistd.h>
  36. #include <functional>
  37. #include "../include/ggml-cann.h"
  38. #include "../include/ggml.h"
  39. #include "../ggml-impl.h"
/**
 * @brief Row padding constant for the CANN backend.
 * NOTE(review): its consumers are not in this header; presumably matrix rows
 * are padded to a multiple of this many elements — confirm at the use sites.
 */
#define MATRIX_ROW_PADDING 512
/**
 * @brief Maximum number of streams a backend context can hold; sizes
 * ggml_backend_cann_context::streams.
 */
#define GGML_CANN_MAX_STREAMS 8
/**
 * @brief Handles CANN-related errors by printing an error message and
 * terminating the program.
 *
 * Declared [[noreturn]]: callers such as ACL_CHECK_GEN rely on control never
 * returning from this call. The definition lives outside this header.
 *
 * @param stmt The statement that caused the error (stringified source text).
 * @param func The function in which the error occurred.
 * @param file The file in which the error occurred.
 * @param line The line number at which the error occurred.
 * @param msg The error message.
 */
[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
                                  const char* file, int line, const char* msg);
  53. /**
  54. * @brief Checks the result of a CANN function call and invokes the error
  55. * handler if the call fails.
  56. * @param stmt The CANN function call to check.
  57. * @param success The success code that indicates the call was successful.
  58. * @param error_fn The function to call to retrieve the error message.
  59. */
  60. #define ACL_CHECK_GEN(stmt, success, error_fn) \
  61. do { \
  62. int err_code = (stmt); \
  63. if (err_code != (success)) { \
  64. ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
  65. } \
  66. } while (0);
  67. #define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
  68. /**
  69. * @brief Contains information about CANN devices.
  70. */
  71. struct ggml_cann_device_info {
  72. /**
  73. * @brief Number of CANN devices available.
  74. */
  75. int32_t device_count;
  76. /**
  77. * @brief Information about a single CANN device.
  78. */
  79. struct cann_device_info {
  80. int cc; /**< Compute capability. */
  81. size_t smpb; /**< Maximum shared memory per block. */
  82. bool vmm; /**< Virtual memory support. */
  83. size_t vmm_granularity; /**< Granularity of virtual memory. */
  84. size_t total_vram; /**< Total video RAM available on the device. */
  85. };
  86. cann_device_info devices[GGML_CANN_MAX_DEVICES] =
  87. {}; /**< Array of CANN device information. */
  88. };
/**
 * @brief Returns a read-only reference to the CANN device information.
 * NOTE(review): defined outside this header; presumably the table is filled
 * once and cached — confirm before relying on it.
 */
const ggml_cann_device_info& ggml_cann_info();
/**
 * @brief Makes @p device the active CANN device.
 * NOTE(review): cann_task_queue's worker thread calls this when it starts,
 * which suggests the setting is per-thread — confirm in the definition.
 * @param device Device ID.
 */
void ggml_cann_set_device(int32_t device);
/**
 * @brief Returns the currently active CANN device.
 * @return Device ID.
 */
int32_t ggml_cann_get_device();
/**
 * @brief Abstract base class for memory pools used by CANN.
 */
struct ggml_cann_pool {
    /**
     * @brief Virtual destructor for the memory pool.
     */
    virtual ~ggml_cann_pool() = default;
    /**
     * @brief Allocates memory from the pool.
     *
     * @param size The size of the memory block to allocate.
     * @param actual_size Pointer to a variable where the actual allocated size
     * will be stored. This may differ from @p size. ggml_cann_pool_alloc passes
     * this value back to free().
     * @return Pointer to the allocated memory block.
     */
    virtual void* alloc(size_t size, size_t* actual_size) = 0;
    /**
     * @brief Frees a previously allocated memory block.
     *
     * @param ptr Pointer to the memory block to free.
     * @param size Size of the memory block to free.
     * @note All CANN operators run asynchronously. Make sure the memory is no
     * longer needed by any operator that may still be running before it is
     * freed.
     */
    virtual void free(void* ptr, size_t size) = 0;
};
  119. /**
  120. * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
  121. */
  122. struct ggml_cann_pool_alloc {
  123. ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
  124. void* ptr = nullptr; /**< Pointer to the allocated memory block. */
  125. size_t actual_size = 0; /**< Actual size of the allocated memory block. */
  126. /**
  127. * @brief Default constructor.
  128. */
  129. ggml_cann_pool_alloc() = default;
  130. /**
  131. * @brief Constructor that initializes the memory pool.
  132. * @param pool Reference to the memory pool.
  133. */
  134. explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
  135. /**
  136. * @brief Constructor that initializes the memory pool and allocates memory.
  137. * @param pool Reference to the memory pool.
  138. * @param size Size of the memory block to allocate.
  139. */
  140. ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
  141. alloc(size);
  142. }
  143. /**
  144. * @brief Destructor that frees the allocated memory block.
  145. */
  146. ~ggml_cann_pool_alloc() {
  147. if (ptr != nullptr) {
  148. pool->free(ptr, actual_size);
  149. }
  150. }
  151. /**
  152. * @brief Allocates memory from the pool.
  153. * @param size Size of the memory block to allocate.
  154. * @return Pointer to the allocated memory block.
  155. */
  156. void* alloc(size_t size) {
  157. GGML_ASSERT(pool != nullptr);
  158. GGML_ASSERT(ptr == nullptr);
  159. ptr = pool->alloc(size, &this->actual_size);
  160. return ptr;
  161. }
  162. /**
  163. * @brief Allocates memory from a specific memory pool.
  164. * @param pool Reference to the memory pool.
  165. * @param size Size of the memory block to allocate.
  166. * @return Pointer to the allocated memory block.
  167. */
  168. void* alloc(ggml_cann_pool& pool, size_t size) {
  169. this->pool = &pool;
  170. return alloc(size);
  171. }
  172. /**
  173. * @brief Gets the pointer to the allocated memory block.
  174. * @return Pointer to the allocated memory block.
  175. */
  176. void* get() { return ptr; }
  177. // Deleted copy constructor
  178. ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
  179. // Deleted move constructor
  180. ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
  181. // Deleted copy assignment operator
  182. ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
  183. // Deleted move assignment operator
  184. ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
  185. };
/**
 * @brief Function pointer type for ACLNN operator calls.
 *
 * The parameters are (void*, uint64_t, aclOpExecutor*, aclrtStream) and the
 * return type is aclnnStatus.
 * NOTE(review): presumably (workspace address, workspace size, executor,
 * stream) — confirm against the ACLNN operator API.
 */
using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
  190. /**
  191. * @brief Base class for all CANN tasks to be submitted to the task queue.
  192. *
  193. * Users should override the run_task() method with actual task logic.
  194. */
  195. class cann_task {
  196. public:
  197. virtual void run_task() {}
  198. };
  199. /**
  200. * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
  201. */
  202. class cann_task_queue {
  203. public:
  204. /**
  205. * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
  206. *
  207. * @param capacity Queue capacity. Must be a power of 2.
  208. * @param device Target device ID (used for context setting).
  209. */
  210. explicit cann_task_queue(size_t capacity, int32_t device)
  211. : buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
  212. running_(false), device_(device) {
  213. GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
  214. mask_ = capacity_ - 1;
  215. }
  216. /**
  217. * @brief Attempts to enqueue a task into the queue.
  218. *
  219. * @param item Unique pointer to the task.
  220. * @return true if the task was successfully enqueued, false if the queue was full.
  221. */
  222. bool enqueue(std::unique_ptr<cann_task>&& item) {
  223. size_t next_tail = (tail_ + 1) & mask_;
  224. if (next_tail == head_) {
  225. return false;
  226. }
  227. buffer_[tail_] = std::move(item);
  228. std::atomic_thread_fence(std::memory_order_release);
  229. tail_ = next_tail;
  230. return true;
  231. }
  232. /**
  233. * @brief Submits a task to the queue, and starts the worker thread if not already running.
  234. *
  235. * @param task Task to be submitted.
  236. */
  237. void submit_task(std::unique_ptr<cann_task>&& task) {
  238. while(!enqueue(std::move(task))) {
  239. std::this_thread::yield();
  240. continue;
  241. }
  242. if (!running_) {
  243. running_ = true;
  244. thread_ = std::thread(&cann_task_queue::execute, this);
  245. }
  246. }
  247. /**
  248. * @brief Waits until the queue is completely empty and no tasks are being processed.
  249. */
  250. void wait() {
  251. while (running_ && head_ != tail_) {
  252. std::this_thread::yield();
  253. continue;
  254. }
  255. }
  256. /**
  257. * @brief Stops the task queue and joins the worker thread.
  258. */
  259. void stop() {
  260. running_ = false;
  261. if (thread_.joinable()) {
  262. thread_.join();
  263. }
  264. }
  265. private:
  266. /**
  267. * @brief Worker thread function that continuously dequeues and executes tasks.
  268. */
  269. void execute() {
  270. ggml_cann_set_device(device_);
  271. while (running_) {
  272. if(head_ == tail_) {
  273. std::this_thread::yield();
  274. continue;
  275. }
  276. std::atomic_thread_fence(std::memory_order_acquire);
  277. buffer_[head_]->run_task();
  278. buffer_[head_].reset();
  279. head_ = (head_ + 1) & mask_;
  280. }
  281. }
  282. std::vector<std::unique_ptr<cann_task>> buffer_;
  283. const size_t capacity_;
  284. size_t mask_;
  285. size_t head_;
  286. size_t tail_;
  287. bool running_;
  288. std::thread thread_;
  289. int32_t device_;
  290. };
  291. /**
  292. * @brief Context for managing CANN backend operations.
  293. */
  294. struct ggml_backend_cann_context {
  295. int32_t device; /**< Device ID. */
  296. std::string name; /**< Name of the device. */
  297. std::string description; /**< Description of the device. */
  298. aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
  299. cann_task_queue task_queue;
  300. bool async_mode;
  301. aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
  302. /**
  303. * @brief Constructor for initializing the context with a given device.
  304. * @param device Device ID.
  305. */
  306. explicit ggml_backend_cann_context(int device)
  307. : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
  308. ggml_cann_set_device(device);
  309. description = aclrtGetSocName();
  310. async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
  311. GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
  312. device, async_mode ? "ON" : "OFF");
  313. }
  314. /**
  315. * @brief Destructor for cleaning up resources.
  316. */
  317. ~ggml_backend_cann_context() {
  318. ggml_cann_set_device(device);
  319. task_queue.stop();
  320. if (copy_event != nullptr) {
  321. ACL_CHECK(aclrtDestroyEvent(copy_event));
  322. }
  323. for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
  324. if (streams[i] != nullptr) {
  325. ACL_CHECK(aclrtDestroyStream(streams[i]));
  326. }
  327. }
  328. }
  329. /**
  330. * @brief Get or create a stream for a given index.
  331. * @param stream Index of the stream.
  332. * @return The stream corresponding to the given index.
  333. */
  334. aclrtStream stream(int stream) {
  335. if (streams[stream] == nullptr) {
  336. ggml_cann_set_device(device);
  337. ACL_CHECK(aclrtCreateStream(&streams[stream]));
  338. }
  339. return streams[stream];
  340. }
  341. /**
  342. * @brief Get or create the default stream (index 0).
  343. * @return The default stream.
  344. */
  345. aclrtStream stream() { return stream(0); }
  346. // TODO: each stream should have a memory pool.
  347. std::unique_ptr<ggml_cann_pool>
  348. mem_pool; /**< Memory pool for the device. */
  349. /**
  350. * @brief Create a new memory pool for a given device.
  351. * @param device Device ID.
  352. * @return A unique pointer to the new memory pool.
  353. */
  354. static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
  355. /**
  356. * @brief Get or create the memory pool for the context.
  357. * @return Reference to the memory pool.
  358. */
  359. ggml_cann_pool& pool() {
  360. if (mem_pool == nullptr) {
  361. mem_pool = new_pool_for_device(device);
  362. }
  363. return *mem_pool;
  364. }
  365. };
  366. #endif // CANN_COMMON_H