|
|
@@ -3809,11 +3809,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+#ifdef GGML_SYCL_GRAPH
|
|
|
+static bool check_graph_compatibility(ggml_cgraph * cgraph) {
|
|
|
+ if (ggml_sycl_info().device_count > 1) {
|
|
|
+ // A sycl_ex::command_graph object can only be created for a single device
|
|
|
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
|
+ const ggml_op node_op = cgraph->nodes[i]->op;
|
|
|
+ switch (node_op) {
|
|
|
+ default:
|
|
|
+ break;
|
|
|
+ case GGML_OP_CONCAT:
|
|
|
+ // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
|
|
|
+ // but wait() can't be called on the events returned by a queue recording
|
|
|
+ // to a graph.
|
|
|
+ [[fallthrough]];
|
|
|
+ case GGML_OP_MUL_MAT_ID:
|
|
|
+ // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
|
|
|
+ // submitting a memcpy operation, but wait() can't be called on a queue that
|
|
|
+ // is recording to a graph.
|
|
|
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
|
|
|
+ ggml_op_name(node_op));
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
|
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
|
|
|
|
|
|
#ifdef GGML_SYCL_GRAPH
|
|
|
- if (!g_ggml_sycl_disable_graph) {
|
|
|
+ bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
|
|
|
+ if (use_sycl_graph) {
|
|
|
const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
|
|
|
if (!graph_support) {
|
|
|
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
|