Просмотр исходного кода

CANN: Improve device ID handling and aclnnArange checks (#16752)

* cann: improve device ID handling and aclnnArange checks

- Stop relying on CANN's internal device ID retrieval; use a global variable instead.
- Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.

* cann: use thread local var
Chenguang Li 2 месяцев назад
Родитель
Сommit
3479efd112
2 измененных файлов с 18 добавлено и 7 удалено
  1. 2 2
      ggml/src/ggml-cann/aclnn_ops.cpp
  2. 16 5
      ggml/src/ggml-cann/ggml-cann.cpp

+ 2 - 2
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
                               ACL_MEM_MALLOC_HUGE_FIRST));
                               ACL_MEM_MALLOC_HUGE_FIRST));
 
 
         acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
         acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
-                                                         theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+                                                         theta_scale_ne, theta_scale_nb, 1);
 
 
         float start      = 0;
         float start      = 0;
         float step       = 1;
         float step       = 1;
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
             yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
             yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
             void * yarn_ramp_buffer = yarn_ramp_allocator.get();
             void * yarn_ramp_buffer = yarn_ramp_allocator.get();
             acl_yarn_ramp_tensor   = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
             acl_yarn_ramp_tensor   = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
-                                                             theta_scale_nb, GGML_MAX_DIMS);
+                                                             theta_scale_nb, 1);
             float       zero_value = 0, one_value = 1;
             float       zero_value = 0, one_value = 1;
             float       denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
             float       denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
             aclScalar * low              = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
             aclScalar * low              = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);

+ 16 - 5
ggml/src/ggml-cann/ggml-cann.cpp

@@ -67,19 +67,30 @@
     GGML_ABORT("CANN error");
     GGML_ABORT("CANN error");
 }
 }
 
 
+// Thread-local variable to record the current device of this thread.
+thread_local int g_current_cann_device = -1;
+
 /**
 /**
- * @brief Sets the device to be used by CANN.
+ * @brief Set the CANN device to be used.
  *
  *
- * @param device The device ID to set.
+ * @param device The target device ID to set.
  */
  */
 void ggml_cann_set_device(const int32_t device) {
 void ggml_cann_set_device(const int32_t device) {
-    int current_device = -1;
-    aclrtGetDevice(&current_device);
+    // int current_device = -1;
+    // Note: In some CANN versions, if no device has been set yet,
+    //       aclrtGetDevice(&current_device) may return 0 by default.
+    // aclrtGetDevice(&current_device);
 
 
-    if (device == current_device) {
+    // If the current device is already the target one, no need to switch.
+    if (device == g_current_cann_device) {
         return;
         return;
     }
     }
+
+    // Switch to the new device.
     ACL_CHECK(aclrtSetDevice(device));
     ACL_CHECK(aclrtSetDevice(device));
+
+    // Update the global device record.
+    g_current_cann_device = device;
 }
 }
 
 
 /**
 /**