Просмотр исходного кода

SYCL: set extras only on GGML_TYPE_Q4_0 (#12366)

* SYCL: set extras only on GGML_TYPE_Q4_0

* release tensor_extras in reset buffer interface
Akarshan Biswas 10 месяцев назад
Родитель
Сommit
b3c9a65673
1 измененных файлов с 22 добавлено и 7 удалено
  1. 22 7
      ggml/src/ggml-sycl/ggml-sycl.cpp

+ 22 - 7
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -333,10 +333,11 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-
-    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
-    tensor->extra = extra;
-    ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+        tensor->extra                 = extra;
+        ctx->tensor_extras.push_back(extra);  //used to release it when destroy ctx.
+    }
 
     if (ggml_is_quantized(tensor->type)) {
         // initialize padding to 0 to avoid possible NaN values
@@ -486,6 +487,22 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+    if (buffer == nullptr) {
+        return;
+    }
+
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+
+    if (ctx != nullptr) {
+        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
+            release_extra_gpu(extra);
+        }
+        ctx->tensor_extras.clear();  // reset the tensor_extras vector
+    }
+}
+
 static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
@@ -495,7 +512,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
     /* .clear           = */ ggml_backend_sycl_buffer_clear,
-    /* .reset           = */ NULL,
+    /* .reset           = */ ggml_backend_sycl_buffer_reset,
 };
 
 // sycl buffer type
@@ -576,7 +593,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     auto dev_count = ggml_backend_sycl_get_device_count();
 
@@ -3761,7 +3777,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
 }
 
 int ggml_backend_sycl_get_device_count() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
     return ggml_sycl_info().device_count;
 }