瀏覽代碼

CANN: Add support for async operator submission (#12864)

Submit operators using asynchronous threads to improve performance.

Use the environment variable GGML_CANN_ASYNC_MODE to control whether
asynchronous submission is enabled. It is disabled by default.

Testing shows a 10%–20% performance improvement in scenarios with
small parameter sizes, especially in quantized models.
hipudding 9 月之前
父節點
當前提交
7a395f67a7
共有 4 個文件被更改,包括 593 次插入337 次删除
  1. 153 257
      ggml/src/ggml-cann/aclnn_ops.cpp
  2. 277 47
      ggml/src/ggml-cann/aclnn_ops.h
  3. 135 1
      ggml/src/ggml-cann/common.h
  4. 28 32
      ggml/src/ggml-cann/ggml-cann.cpp

文件差異過大導致無法顯示
+ 153 - 257
ggml/src/ggml-cann/aclnn_ops.cpp


+ 277 - 47
ggml/src/ggml-cann/aclnn_ops.h

@@ -23,6 +23,7 @@
 #ifndef CANN_ACLNN_OPS
 #define CANN_ACLNN_OPS
 
+#include <functional>
 #include <aclnnop/aclnn_abs.h>
 #include <aclnnop/aclnn_neg.h>
 #include <aclnnop/aclnn_exp.h>
@@ -713,6 +714,270 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/*
+ * @brief A generic wrapper for ACL resources with custom deleter support.
+ */
+using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
+
+/**
+ * @brief Trait structure used to define how to destroy a given ACL resource type.
+ *
+ * @tparam T ACL resource type.
+ */
+template<typename T>
+struct acl_resource_traits;
+
+/**
+ * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
+ */
+template<>
+struct acl_resource_traits<aclTensor> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
+ */
+template<>
+struct acl_resource_traits<aclIntArray> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
+ */
+template<>
+struct acl_resource_traits<aclScalar> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
+ */
+template<>
+struct acl_resource_traits<aclTensorList> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
+    }
+};
+
+/**
+ * @brief Creates a generic ACL resource wrapper with proper destruction logic.
+ *
+ * @tparam T ACL resource type.
+ * @param ptr Raw pointer to ACL resource.
+ * @return any_acl_resource Smart pointer that handles destruction.
+ */
+template<typename T>
+any_acl_resource make_acl_resource(T* ptr) {
+    return any_acl_resource(
+        static_cast<void*>(ptr),
+        [](void* p) {
+            acl_resource_traits<T>::destroy(p);
+        }
+    );
+}
+
+/**
+ * @brief Registers multiple ACL resources into a vector for lifetime management.
+ *
+ * @tparam Args Variadic list of ACL resource types.
+ * @param vec Target vector to hold ACL resources.
+ * @param args Raw pointers to ACL resources.
+ */
+template<typename... Args>
+void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
+    (vec.emplace_back(make_acl_resource(args)), ...);
+}
+
+/**
+ * @brief Task class that wraps the execution of an aclnn function call.
+ */
+class aclnn_task : public cann_task {
+    public:
+        aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
+                   uint64_t workspace_size, aclOpExecutor * executor,
+                   aclrtStream stream) :
+            aclnn_func_(aclnn_func),
+            workspace_addr_(workspace_addr),
+            workspace_size_(workspace_size),
+            executor_(executor),
+            stream_(stream) {}
+        virtual void run_task() override {
+            ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
+        }
+    private:
+        aclnn_func_t aclnn_func_;
+        void *          workspace_addr_;
+        uint64_t        workspace_size_;
+        aclOpExecutor * executor_;
+        aclrtStream     stream_;
+};
+
+/**
+ * @brief Task class that releases ACL resources after usage.
+ */
+class release_resource_task : public cann_task {
+public:
+    release_resource_task(std::vector<any_acl_resource>&& resources){
+        resource_ = std::move(resources);
+    }
+
+    virtual void run_task() override {
+        resource_.clear();
+    }
+private:
+    std::vector<any_acl_resource> resource_;
+};
+
+/**
+ * @brief Task class for performing asynchronous memory copy operations.
+ */
+class async_memcpy_task : public cann_task {
+public:
+    async_memcpy_task(void* dst, const void* src, size_t size,
+                      aclrtMemcpyKind kind, aclrtStream stream)
+        : dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
+
+    virtual void run_task() override {
+        ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
+    }
+private:
+    void* dst_;
+    const void* src_;
+    size_t size_;
+    aclrtMemcpyKind kind_;
+    aclrtStream stream_;
+};
+
+/**
+ * @brief Task class for performing asynchronous memory set operations.
+ */
+class async_memset_task : public cann_task {
+    public:
+    async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
+            : buffer_(buffer), size_(size), value_(value), stream_(stream) {}
+
+        virtual void run_task() override {
+            ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
+        }
+    private:
+        void* buffer_;
+        size_t size_;
+        int32_t value_;
+        aclrtStream stream_;
+};
+
+/**
+ * @brief Launches an asynchronous task using the memory allocator.
+ *
+ * This macro submit an asynchronous task on the specified stream.
+ * The task uses memory allocated by the allocator. It is guaranteed
+ * that the memory will not be accessed by other tasks until this task
+ * completes, due to the sequential execution order within the same stream.
+ *
+ * @param OP_NAME aclnn operator name.
+ * @param args Additional arguments required by the task.
+ *
+ * @note
+ * Memory from the allocator will be "freed" immediately and can be
+ * reallocated to other pointers. However, it won't be accessed by any
+ * other task before this asynchronous task ends, because all tasks in the
+ * same stream are executed in queue order.
+ */
+
+#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                          \
+    do {                                                                                    \
+        uint64_t        workspaceSize = 0;                                                  \
+        aclOpExecutor * executor;                                                           \
+        void *          workspaceAddr = nullptr;                                            \
+        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
+        /* workspace should alloced in main thread to keep malloc order when using vmm. */  \
+        if (workspaceSize > 0) {                                                            \
+            ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);            \
+            workspaceAddr = workspace_allocator.get();                                      \
+        }                                                                                   \
+        if (CTX.async_mode) {                                                               \
+            auto task =                                                                     \
+                std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize,  \
+                    executor, CTX.stream()); \
+            CTX.task_queue.submit_task(std::move(task));                                    \
+        } else {                                                                            \
+            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
+        }                                                                                   \
+    } while (0)
+
+/**
+ * @brief Registers and releases multiple ACL resources, optionally deferring the release
+ *        using a task.
+ *
+ * @tparam Args Types of the ACL resources.
+ * @param ctx Backend context which manages task submission and async mode.
+ * @param args Pointers to ACL resources to be released.
+ */
+template <typename... Args>
+void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
+    std::vector<any_acl_resource> resources;
+    register_acl_resources(resources, std::forward<Args>(args)...);
+    if(ctx.async_mode) {
+        auto task = std::make_unique<release_resource_task>(std::move(resources));
+        ctx.task_queue.submit_task(std::move(task));
+    }
+}
+
+/**
+ * @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
+ *
+ * @param ctx Backend context containing stream and async configuration.
+ * @param dst Destination memory address.
+ * @param src Source memory address.
+ * @param len Size of memory to copy (in bytes).
+ * @param kind Type of memory copy (host-to-device, device-to-host, etc).
+ */
+inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
+                                   const void * src, size_t len, aclrtMemcpyKind kind) {
+    if (ctx.async_mode) {
+        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
+        ctx.task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
+    }
+}
+
+inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
+                                   const void * src, size_t len, aclrtMemcpyKind kind) {
+    if (ctx->async_mode) {
+        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
+        ctx->task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
+    }
+}
+
+/**
+ * @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
+ *
+ * @param ctx Backend context containing stream and async configuration.
+ * @param buffer Memory buffer to be set.
+ * @param size Size of the memory buffer (in bytes).
+ * @param value Value to set in the buffer.
+ */
+inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
+                                   size_t size, int value) {
+    if (ctx.async_mode) {
+        auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
+        ctx.task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
+    }
+}
+
 /**
  * @brief Applies a element-wise operation to two input tensors using the CANN
  * backend.
@@ -742,42 +1007,9 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
     binary_op(ctx, acl_src0, acl_src1, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
 }
 
-/**
- * @brief Launches an asynchronous task using the memory allocator.
- *
- * This macro submit an asynchronous task on the specified stream.
- * The task uses memory allocated by the allocator. It is guaranteed
- * that the memory will not be accessed by other tasks until this task
- * completes, due to the sequential execution order within the same stream.
- *
- * @param OP_NAME aclnn operator name.
- * @param args Additional arguments required by the task.
- *
- * @note
- * Memory from the allocator will be "freed" immediately and can be
- * reallocated to other pointers. However, it won't be accessed by any
- * other task before this asynchronous task ends, because all tasks in the
- * same stream are executed in queue order.
- */
-#define GGML_CANN_CALL_ACLNN_OP(OP_NAME, ...)                                                \
-    do {                                                                                     \
-        uint64_t        workspaceSize = 0;                                                   \
-        aclOpExecutor * executor;                                                            \
-        void *          workspaceAddr = nullptr;                                             \
-                                                                                             \
-        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
-                                                                                             \
-        if (workspaceSize > 0) {                                                             \
-            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);             \
-            workspaceAddr = workspace_allocator.get();                                       \
-        }                                                                                    \
-        ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream()));     \
-    } while (0)
 
 /**
  * @brief Applies a unary operation to an input tensor using the CANN backend.
@@ -799,9 +1031,7 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
     unary_op(ctx, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -832,7 +1062,7 @@ void ggml_cann_unary_op(
  *
  * Internally, the lambda will call:
  * @code
- * GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
  * @endcode
  *
  * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
@@ -840,14 +1070,14 @@ void ggml_cann_unary_op(
  * @see ggml_cann_unary_op
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                         \
-    do {                                                         \
-        auto lambda = [](ggml_backend_cann_context& ctx,         \
-            aclTensor* acl_src,                                  \
-            aclTensor* acl_dst) {                                \
-            GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);  \
-        };                                                       \
-        ggml_cann_unary_op(lambda, ctx, dst);                    \
-    }                                                            \
+#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                              \
+    do {                                                              \
+        auto lambda = [](ggml_backend_cann_context& ctx,              \
+            aclTensor* acl_src,                                       \
+            aclTensor* acl_dst) {                                     \
+            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
+        };                                                            \
+        ggml_cann_unary_op(lambda, ctx, dst);                         \
+    }                                                                 \
     while (0)
 #endif  // CANN_ACLNN_OPS

+ 135 - 1
ggml/src/ggml-cann/common.h

@@ -31,9 +31,16 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+#include <unistd.h>
+#include <functional>
 
 #include "../include/ggml-cann.h"
 #include "../include/ggml.h"
+#include "../ggml-impl.h"
 
 #define MATRIX_ROW_PADDING 512
 #define GGML_CANN_MAX_STREAMS 8
@@ -205,6 +212,127 @@ struct ggml_cann_pool_alloc {
     ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
 };
 
+/**
+ * @brief Function pointer type for ACLNN operator calls.
+ */
+using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
+
+/**
+ * @brief Base class for all CANN tasks to be submitted to the task queue.
+ *
+ * Users should override the run_task() method with actual task logic.
+ */
+class cann_task {
+public:
+    virtual void run_task() {}
+};
+
+/**
+ * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
+ */
+class cann_task_queue {
+public:
+    /**
+     * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
+     *
+     * @param capacity Queue capacity. Must be a power of 2.
+     * @param device Target device ID (used for context setting).
+     */
+    explicit cann_task_queue(size_t capacity, int32_t device)
+        : buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
+          running_(false), device_(device) {
+        GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
+        mask_ = capacity_ - 1;
+    }
+
+    /**
+     * @brief Attempts to enqueue a task into the queue.
+     *
+     * @param item Unique pointer to the task.
+     * @return true if the task was successfully enqueued, false if the queue was full.
+     */
+    bool enqueue(std::unique_ptr<cann_task>&& item) {
+        size_t next_tail = (tail_ + 1) & mask_;
+
+        if (next_tail == head_) {
+            return false;
+        }
+
+        buffer_[tail_] = std::move(item);
+        std::atomic_thread_fence(std::memory_order_release);
+        tail_ = next_tail;
+
+        return true;
+    }
+
+    /**
+     * @brief Submits a task to the queue, and starts the worker thread if not already running.
+     *
+     * @param task Task to be submitted.
+     */
+    void submit_task(std::unique_ptr<cann_task>&& task) {
+        while(!enqueue(std::move(task))) {
+            std::this_thread::yield();
+            continue;
+        }
+
+        if (!running_) {
+            running_ = true;
+            thread_ = std::thread(&cann_task_queue::execute, this);
+        }
+
+    }
+
+    /**
+     * @brief Waits until the queue is completely empty and no tasks are being processed.
+     */
+    void wait() {
+        while (running_ && head_ != tail_) {
+            std::this_thread::yield();
+            continue;
+        }
+    }
+
+    /**
+     * @brief Stops the task queue and joins the worker thread.
+     */
+    void stop() {
+        running_ = false;
+        if (thread_.joinable()) {
+            thread_.join();
+        }
+    }
+
+private:
+    /**
+     * @brief Worker thread function that continuously dequeues and executes tasks.
+     */
+    void execute() {
+        ggml_cann_set_device(device_);
+
+        while (running_) {
+            if(head_ == tail_) {
+                std::this_thread::yield();
+                continue;
+            }
+
+            std::atomic_thread_fence(std::memory_order_acquire);
+            buffer_[head_]->run_task();
+            buffer_[head_].reset();
+            head_ = (head_ + 1) & mask_;
+        }
+    }
+
+    std::vector<std::unique_ptr<cann_task>> buffer_;
+    const size_t capacity_;
+    size_t mask_;
+    size_t head_;
+    size_t tail_;
+    bool running_;
+    std::thread thread_;
+    int32_t device_;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -213,6 +341,8 @@ struct ggml_backend_cann_context {
     std::string name;                /**< Name of the device. */
     std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+    cann_task_queue task_queue;
+    bool async_mode;
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
@@ -221,9 +351,12 @@ struct ggml_backend_cann_context {
      * @param device Device ID.
      */
     explicit ggml_backend_cann_context(int device)
-        : device(device), name("CANN" + std::to_string(device)) {
+        : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
         ggml_cann_set_device(device);
         description = aclrtGetSocName();
+        async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
+        GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
+            device, async_mode ? "ON" : "OFF");
     }
 
     /**
@@ -231,6 +364,7 @@ struct ggml_backend_cann_context {
      */
     ~ggml_backend_cann_context() {
         ggml_cann_set_device(device);
+        task_queue.stop();
         if (copy_event != nullptr) {
             ACL_CHECK(aclrtDestroyEvent(copy_event));
         }

+ 28 - 32
ggml/src/ggml-cann/ggml-cann.cpp

@@ -1606,7 +1606,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                     auto lambda = [](ggml_backend_cann_context& ctx,
                         aclTensor* acl_src,
                         aclTensor* acl_dst) {
-                        GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
+                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                     };
                     ggml_cann_unary_op(lambda, ctx, dst);
                 } break;
@@ -1789,12 +1789,11 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
     delete backend;
 }
 
+
 /**
  * @brief Sets tensor data asynchronously in the CANN backend.
  *
- * This function asynchronously sets tensor data in the CANN backend. Depending
- * on the tensor type, it may perform data transformations before copying data
- * to the device.
+ * This function asynchronously sets tensor data in the CANN backend.
  *
  * @param backend Pointer to the CANN backend structure.
  * @param tensor Pointer to the tensor structure to set data for.
@@ -1809,23 +1808,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
                                                size_t size) {
     ggml_backend_cann_context *cann_ctx =
         (ggml_backend_cann_context *)backend->context;
+    ggml_backend_buffer_t buf =
+        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
-                                   size, ACL_MEMCPY_HOST_TO_DEVICE,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, data, transform_buffer);
+    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
+        "unsupported buffer type");
+    GGML_ASSERT(!ggml_is_quantized(tensor->type));
 
-        ACL_CHECK(aclrtMemcpyAsync(
-            (char *)tensor->data + offset, size, transform_buffer, size,
-            ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        free(transform_buffer);
-    }
+    ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
+        ACL_MEMCPY_HOST_TO_DEVICE);
 }
 
+/**
+ * @brief Gets tensor data asynchronously in the CANN backend.
+ *
+ * This function asynchronously gets tensor data in the CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @param tensor Pointer to the tensor structure to get data from.
+ * @param data Pointer to the host data to copy from the tensor.
+ * @param offset Offset in bytes within the host data.
+ * @param size Size of the data to copy in bytes.
+ */
 static void ggml_backend_cann_get_tensor_async(
     ggml_backend_t backend, const ggml_tensor *tensor, void *data,
     size_t offset, size_t size) {
@@ -1836,20 +1840,11 @@ static void ggml_backend_cann_get_tensor_async(
 
     GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
                 "unsupported buffer type");
+    GGML_ASSERT(!ggml_is_quantized(tensor->type));
+
+    ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
+        ACL_MEMCPY_DEVICE_TO_HOST);
 
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
-                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ACL_CHECK(aclrtMemcpyAsync(
-            transform_buffer, size, (char *)tensor->data + offset, size,
-            ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
-        free(transform_buffer);
-    }
 }
 
 /**
@@ -1909,6 +1904,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
         ggml_cann_set_device(cann_ctx_src->device);
         ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
 
+        // wait for task_queue empty to keep task order.
+        cann_ctx_src->task_queue.wait();
         ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                    ACL_MEMCPY_DEVICE_TO_DEVICE,
                                    cann_ctx_src->stream()));
@@ -1936,9 +1933,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
 static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-
+    cann_ctx->task_queue.wait();
     ggml_cann_set_device(cann_ctx->device);
-
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 

部分文件因文件數量過多而無法顯示