Просмотр исходного кода

sycl: Remove not needed copy f16->f32 for dnnl mul mat (#14125)

Anton Mitkov 7 месяцев назад
Родитель
Сommit
ed52f3668e
2 измененных файлов с 6 добавлено и 6 удалено
  1. 3 0
      ggml/src/ggml-sycl/gemm.hpp
  2. 3 6
      ggml/src/ggml-sycl/ggml-sycl.cpp

+ 3 - 0
ggml/src/ggml-sycl/gemm.hpp

@@ -65,6 +65,9 @@ public:
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+#ifdef GGML_SYCL_F16
+        primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
+#endif
 
         auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
         auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));

+ 3 - 6
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -2127,21 +2127,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
                 ? (const sycl::half *)src1->data + src1_padded_row_size
                                          : src1_as_f16.get();
-        ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
             DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
                                       DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
-                                                 " : converting dst to fp32");
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
 #endif
         {
+            ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+
             const sycl::half alpha_f16 = 1.0f;
             const sycl::half beta_f16  = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(