Procházet zdrojové kódy

sycl: Hotfix for non dnnl codepath (#14677)

Anton Mitkov před 6 měsíci
rodič
revize
bdca38376f
1 změnil soubory, kde provedl 9 přidání a 1 odebrání
  1. 9 1
      ggml/src/ggml-sycl/ggml-sycl.cpp

+ 9 - 1
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -2875,12 +2875,20 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             }
             }
 
 
         }
         }
+#if GGML_SYCL_DNNL
+        // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
         const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         src1_f16_alloc.alloc(ne_src1);
         src1_f16_alloc.alloc(ne_src1);
-
         const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
         const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
         GGML_ASSERT(to_fp16_sycl != nullptr);
         GGML_ASSERT(to_fp16_sycl != nullptr);
         to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
         to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
+# else
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+#endif
 
 
         src1_f16 = src1_f16_alloc.get();
         src1_f16 = src1_f16_alloc.get();
         s11      = ne10;
         s11      = ne10;