|
|
@@ -31,6 +31,8 @@
|
|
|
#include <mutex>
|
|
|
#include <queue>
|
|
|
#include <chrono>
|
|
|
+#include <unordered_set>
|
|
|
+#include <optional>
|
|
|
|
|
|
#include "ggml-impl.h"
|
|
|
#include "ggml-backend-impl.h"
|
|
|
@@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() {
|
|
|
return id;
|
|
|
}
|
|
|
|
|
|
+/**
|
|
|
+ * @brief Get the value of the specified environment variable (name).
|
|
|
+ * if not empty, return a std::string object
|
|
|
+ */
|
|
|
+std::optional<std::string> get_env(const std::string& name) {
|
|
|
+ const char* val = std::getenv(name.c_str());
|
|
|
+ if (!val) return std::nullopt;
|
|
|
+ std::string res = std::string(val);
|
|
|
+ std::transform(res.begin(), res.end(), res.begin(), ::tolower);
|
|
|
+ return res;
|
|
|
+}
|
|
|
+
|
|
|
+/**
|
|
|
+ * @brief Verify whether the environment variable is a valid value.
|
|
|
+ */
|
|
|
+bool parse_bool(const std::string& value) {
|
|
|
+ std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
|
|
|
+ return valid_values.find(value) != valid_values.end();
|
|
|
+}
|
|
|
+
|
|
|
/**
|
|
|
* @brief Initialize the CANN device information.
|
|
|
*
|
|
|
@@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
|
|
|
* @param device The device ID to associate with this buffer pool.
|
|
|
*/
|
|
|
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
|
|
|
- disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
|
|
+ disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
@@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
|
|
|
* @param device The device ID to associate with this buffer pool.
|
|
|
*/
|
|
|
explicit ggml_cann_pool_buf(int device) : device(device) {
|
|
|
- disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
|
|
+ disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
@@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
|
|
*/
|
|
|
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
|
|
int device) {
|
|
|
- bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
|
|
|
- if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
|
|
|
- GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
|
|
- return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
|
|
- }
|
|
|
- bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr);
|
|
|
- if (enable_buf_prio) {
|
|
|
+ std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
|
|
|
+
|
|
|
+ if (mem_pool_type == "prio") {
|
|
|
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
|
|
|
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
|
|
|
}
|
|
|
+
|
|
|
+ if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
|
|
|
+ GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
|
|
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
|
|
+ }
|
|
|
+
|
|
|
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
|
|
|
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
|
|
|
}
|