// common.h (14 KB)


  1. /*
  2. * Copyright (c) 2023-2024 The ggml authors
  3. *
  4. * Permission is hereby granted, free of charge, to any person obtaining a copy
  5. * of this software and associated documentation files (the "Software"), to
  6. * deal in the Software without restriction, including without limitation the
  7. * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
  8. * sell copies of the Software, and to permit persons to whom the Software is
  9. * furnished to do so, subject to the following conditions:
  10. *
  11. * The above copyright notice and this permission notice shall be included in
  12. * all copies or substantial portions of the Software.
  13. *
  14. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  15. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  16. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  17. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  18. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  19. * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  20. * IN THE SOFTWARE.
  21. */
  22. #ifndef CANN_COMMON_H
  23. #define CANN_COMMON_H
  24. #include <acl/acl.h>
  25. #include <cstdio>
  26. #include <iostream>
  27. #include <map>
  28. #include <memory>
  29. #include <string>
  30. #include <vector>
  31. #include <atomic>
  32. #include <condition_variable>
  33. #include <mutex>
  34. #include <thread>
  35. #include <unistd.h>
  36. #include <functional>
  37. #include <optional>
  38. #include "../include/ggml-cann.h"
  39. #include "../include/ggml.h"
  40. #include "../ggml-impl.h"
  41. #define MATRIX_ROW_PADDING 512
  42. #define GGML_CANN_MAX_STREAMS 8
  43. /**
  44. * @brief Handles CANN-related errors by printing an error message and
  45. * terminating the program.
  46. * @param stmt The statement that caused the error.
  47. * @param func The function in which the error occurred.
  48. * @param file The file in which the error occurred.
  49. * @param line The line number at which the error occurred.
  50. * @param msg The error message.
  51. */
  52. [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
  53. const char* file, int line, const char* msg);
  54. /**
  55. * @brief Checks the result of a CANN function call and invokes the error
  56. * handler if the call fails.
  57. * @param stmt The CANN function call to check.
  58. * @param success The success code that indicates the call was successful.
  59. * @param error_fn The function to call to retrieve the error message.
  60. */
  61. #define ACL_CHECK_GEN(stmt, success, error_fn) \
  62. do { \
  63. int err_code = (stmt); \
  64. if (err_code != (success)) { \
  65. ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
  66. } \
  67. } while (0);
  68. #define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
  69. /**
  70. * @brief Contains information about CANN devices.
  71. */
  72. struct ggml_cann_device_info {
  73. /**
  74. * @brief Number of CANN devices available.
  75. */
  76. int32_t device_count;
  77. /**
  78. * @brief Information about a single CANN device.
  79. */
  80. struct cann_device_info {
  81. int cc; /**< Compute capability. */
  82. size_t smpb; /**< Maximum shared memory per block. */
  83. bool vmm; /**< Virtual memory support. */
  84. size_t vmm_granularity; /**< Granularity of virtual memory. */
  85. size_t total_vram; /**< Total video RAM available on the device. */
  86. };
  87. cann_device_info devices[GGML_CANN_MAX_DEVICES] =
  88. {}; /**< Array of CANN device information. */
  89. };
  90. const ggml_cann_device_info& ggml_cann_info();
  91. void ggml_cann_set_device(int32_t device);
  92. int32_t ggml_cann_get_device();
  93. std::optional<std::string> get_env(const std::string& name);
  94. bool parse_bool(const std::string& value);
  95. /**
  96. * @brief Abstract base class for memory pools used by CANN.
  97. */
  98. struct ggml_cann_pool {
  99. /**
  100. * @brief Virtual destructor for the memory pool.
  101. */
  102. virtual ~ggml_cann_pool() = default;
  103. /**
  104. * @brief Allocates memory from the pool.
  105. *
  106. * @param size The size of the memory block to allocate.
  107. * @param actual_size Pointer to a variable where the actual allocated size
  108. * will be stored.
  109. * @return Pointer to the allocated memory block.
  110. */
  111. virtual void* alloc(size_t size, size_t* actual_size) = 0;
  112. /**
  113. * @brief Frees a previously allocated memory block.
  114. *
  115. * @param ptr Pointer to the memory block to free.
  116. * @param size Size of the memory block to free.
  117. * @note Note that all CANN opertors are running async. Make sure memory is
  118. * still avaiable before this operator finished.
  119. */
  120. virtual void free(void* ptr, size_t size) = 0;
  121. };
  122. /**
  123. * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
  124. */
  125. struct ggml_cann_pool_alloc {
  126. ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
  127. void* ptr = nullptr; /**< Pointer to the allocated memory block. */
  128. size_t actual_size = 0; /**< Actual size of the allocated memory block. */
  129. /**
  130. * @brief Default constructor.
  131. */
  132. ggml_cann_pool_alloc() = default;
  133. /**
  134. * @brief Constructor that initializes the memory pool.
  135. * @param pool Reference to the memory pool.
  136. */
  137. explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
  138. /**
  139. * @brief Constructor that initializes the memory pool and allocates memory.
  140. * @param pool Reference to the memory pool.
  141. * @param size Size of the memory block to allocate.
  142. */
  143. ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
  144. alloc(size);
  145. }
  146. /**
  147. * @brief Destructor that frees the allocated memory block.
  148. */
  149. ~ggml_cann_pool_alloc() {
  150. if (ptr != nullptr) {
  151. pool->free(ptr, actual_size);
  152. }
  153. }
  154. /**
  155. * @brief Allocates memory from the pool.
  156. * @param size Size of the memory block to allocate.
  157. * @return Pointer to the allocated memory block.
  158. */
  159. void* alloc(size_t size) {
  160. GGML_ASSERT(pool != nullptr);
  161. GGML_ASSERT(ptr == nullptr);
  162. ptr = pool->alloc(size, &this->actual_size);
  163. return ptr;
  164. }
  165. /**
  166. * @brief Allocates memory from a specific memory pool.
  167. * @param pool Reference to the memory pool.
  168. * @param size Size of the memory block to allocate.
  169. * @return Pointer to the allocated memory block.
  170. */
  171. void* alloc(ggml_cann_pool& pool, size_t size) {
  172. this->pool = &pool;
  173. return alloc(size);
  174. }
  175. /**
  176. * @brief Gets the pointer to the allocated memory block.
  177. * @return Pointer to the allocated memory block.
  178. */
  179. void* get() { return ptr; }
  180. // Deleted copy constructor
  181. ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
  182. // Deleted move constructor
  183. ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
  184. // Deleted copy assignment operator
  185. ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
  186. // Deleted move assignment operator
  187. ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
  188. };
  189. /**
  190. * @brief Function pointer type for ACLNN operator calls.
  191. */
  192. using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
  193. /**
  194. * @brief Base class for all CANN tasks to be submitted to the task queue.
  195. *
  196. * Users should override the run_task() method with actual task logic.
  197. */
  198. class cann_task {
  199. public:
  200. virtual void run_task() {}
  201. };
  202. /**
  203. * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
  204. */
  205. class cann_task_queue {
  206. public:
  207. /**
  208. * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
  209. *
  210. * @param capacity Queue capacity. Must be a power of 2.
  211. * @param device Target device ID (used for context setting).
  212. */
  213. explicit cann_task_queue(size_t capacity, int32_t device)
  214. : buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
  215. running_(false), device_(device) {
  216. GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
  217. mask_ = capacity_ - 1;
  218. }
  219. /**
  220. * @brief Attempts to enqueue a task into the queue.
  221. *
  222. * @param item Unique pointer to the task.
  223. * @return true if the task was successfully enqueued, false if the queue was full.
  224. */
  225. bool enqueue(std::unique_ptr<cann_task>&& item) {
  226. size_t next_tail = (tail_ + 1) & mask_;
  227. if (next_tail == head_) {
  228. return false;
  229. }
  230. buffer_[tail_] = std::move(item);
  231. std::atomic_thread_fence(std::memory_order_release);
  232. tail_ = next_tail;
  233. return true;
  234. }
  235. /**
  236. * @brief Submits a task to the queue, and starts the worker thread if not already running.
  237. *
  238. * @param task Task to be submitted.
  239. */
  240. void submit_task(std::unique_ptr<cann_task>&& task) {
  241. while(!enqueue(std::move(task))) {
  242. std::this_thread::yield();
  243. continue;
  244. }
  245. if (!running_) {
  246. running_ = true;
  247. thread_ = std::thread(&cann_task_queue::execute, this);
  248. }
  249. }
  250. /**
  251. * @brief Waits until the queue is completely empty and no tasks are being processed.
  252. */
  253. void wait() {
  254. while (running_ && head_ != tail_) {
  255. std::this_thread::yield();
  256. continue;
  257. }
  258. }
  259. /**
  260. * @brief Stops the task queue and joins the worker thread.
  261. */
  262. void stop() {
  263. running_ = false;
  264. if (thread_.joinable()) {
  265. thread_.join();
  266. }
  267. }
  268. private:
  269. /**
  270. * @brief Worker thread function that continuously dequeues and executes tasks.
  271. */
  272. void execute() {
  273. ggml_cann_set_device(device_);
  274. while (running_) {
  275. if(head_ == tail_) {
  276. std::this_thread::yield();
  277. continue;
  278. }
  279. std::atomic_thread_fence(std::memory_order_acquire);
  280. buffer_[head_]->run_task();
  281. buffer_[head_].reset();
  282. head_ = (head_ + 1) & mask_;
  283. }
  284. }
  285. std::vector<std::unique_ptr<cann_task>> buffer_;
  286. const size_t capacity_;
  287. size_t mask_;
  288. size_t head_;
  289. size_t tail_;
  290. bool running_;
  291. std::thread thread_;
  292. int32_t device_;
  293. };
  294. #ifdef USE_ACL_GRAPH
  295. struct ggml_graph_node_properties {
  296. void * node_address;
  297. ggml_op node_op;
  298. int64_t ne[GGML_MAX_DIMS];
  299. size_t nb[GGML_MAX_DIMS];
  300. void * src_address[GGML_MAX_SRC];
  301. int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
  302. };
  303. struct ggml_cann_graph {
  304. ~ggml_cann_graph() {
  305. if (graph != nullptr) {
  306. aclmdlRIDestroy(graph);
  307. }
  308. }
  309. aclmdlRI graph = nullptr;
  310. std::vector<ggml_graph_node_properties> ggml_graph_properties;
  311. };
  312. #endif // USE_ACL_GRAPH
  313. /**
  314. * @brief Context for managing CANN backend operations.
  315. */
  316. struct ggml_backend_cann_context {
  317. int32_t device; /**< Device ID. */
  318. std::string name; /**< Name of the device. */
  319. std::string description; /**< Description of the device. */
  320. aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
  321. #ifdef USE_ACL_GRAPH
  322. /// Cached CANN ACL graph used for executing the current ggml computation graph.
  323. std::unique_ptr<ggml_cann_graph> cann_graph;
  324. #endif
  325. cann_task_queue task_queue;
  326. bool async_mode;
  327. bool support_set_rows;
  328. aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
  329. /**
  330. * @brief Constructor for initializing the context with a given device.
  331. * @param device Device ID.
  332. */
  333. explicit ggml_backend_cann_context(int device)
  334. : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
  335. ggml_cann_set_device(device);
  336. description = aclrtGetSocName();
  337. async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
  338. GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
  339. device, async_mode ? "ON" : "OFF");
  340. support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
  341. GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
  342. if (!support_set_rows) {
  343. GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
  344. "Falling back to eager mode.\n", __func__);
  345. }
  346. }
  347. /**
  348. * @brief Destructor for cleaning up resources.
  349. */
  350. ~ggml_backend_cann_context() {
  351. ggml_cann_set_device(device);
  352. task_queue.stop();
  353. if (copy_event != nullptr) {
  354. ACL_CHECK(aclrtDestroyEvent(copy_event));
  355. }
  356. for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
  357. if (streams[i] != nullptr) {
  358. ACL_CHECK(aclrtDestroyStream(streams[i]));
  359. }
  360. }
  361. }
  362. /**
  363. * @brief Get or create a stream for a given index.
  364. * @param stream Index of the stream.
  365. * @return The stream corresponding to the given index.
  366. */
  367. aclrtStream stream(int stream) {
  368. if (streams[stream] == nullptr) {
  369. ggml_cann_set_device(device);
  370. ACL_CHECK(aclrtCreateStream(&streams[stream]));
  371. }
  372. return streams[stream];
  373. }
  374. /**
  375. * @brief Get or create the default stream (index 0).
  376. * @return The default stream.
  377. */
  378. aclrtStream stream() { return stream(0); }
  379. // TODO: each stream should have a memory pool.
  380. std::unique_ptr<ggml_cann_pool>
  381. mem_pool; /**< Memory pool for the device. */
  382. /**
  383. * @brief Create a new memory pool for a given device.
  384. * @param device Device ID.
  385. * @return A unique pointer to the new memory pool.
  386. */
  387. static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
  388. /**
  389. * @brief Get or create the memory pool for the context.
  390. * @return Reference to the memory pool.
  391. */
  392. ggml_cann_pool& pool() {
  393. if (mem_pool == nullptr) {
  394. mem_pool = new_pool_for_device(device);
  395. }
  396. return *mem_pool;
  397. }
  398. };
  399. #endif // CANN_COMMON_H