|
@@ -1689,9 +1689,14 @@ namespace dpct
|
|
|
auto data_a = get_memory<const Ta>(a);
|
|
auto data_a = get_memory<const Ta>(a);
|
|
|
auto data_b = get_memory<const Tb>(b);
|
|
auto data_b = get_memory<const Tb>(b);
|
|
|
auto data_c = get_memory<Tc>(c);
|
|
auto data_c = get_memory<Tc>(c);
|
|
|
- oneapi::mkl::blas::column_major::gemm(
|
|
|
|
|
- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
|
|
|
|
- data_b, ldb, beta_value, data_c, ldc);
|
|
|
|
|
|
|
+#ifdef GGML_SYCL_NVIDIA
|
|
|
|
|
+ oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
|
|
|
|
+ a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
|
|
|
+ beta_value, data_c, ldc);
|
|
|
|
|
+#else
|
|
|
|
|
+ oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
|
|
|
+ beta_value, data_c, ldc);
|
|
|
|
|
+#endif
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
template <typename VecT, class BinaryOperation, class = void>
|
|
template <typename VecT, class BinaryOperation, class = void>
|
|
@@ -1754,14 +1759,22 @@ namespace dpct
|
|
|
matrix_info->ld_info[2] = ldc;
|
|
matrix_info->ld_info[2] = ldc;
|
|
|
matrix_info->groupsize_info = batch_size;
|
|
matrix_info->groupsize_info = batch_size;
|
|
|
|
|
|
|
|
|
|
+#ifdef GGML_SYCL_NVIDIA
|
|
|
|
|
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
|
|
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
|
|
|
|
+ matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
|
|
|
|
+ matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
|
|
|
|
+ matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
|
|
|
+ matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
|
|
|
|
+ &(matrix_info->groupsize_info));
|
|
|
|
|
+#else
|
|
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
- q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
|
|
|
|
- matrix_info->size_info, matrix_info->size_info + 1,
|
|
|
|
|
- matrix_info->size_info + 2, matrix_info->value_info,
|
|
|
|
|
- reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
|
|
|
|
- reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
|
|
|
- matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
|
|
|
|
|
|
+ q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
|
|
|
|
+ matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
|
|
|
|
+ reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
|
|
|
+ matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
|
|
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
|
|
+#endif
|
|
|
|
|
|
|
|
q.submit([&](sycl::handler &cgh)
|
|
q.submit([&](sycl::handler &cgh)
|
|
|
{
|
|
{
|
|
@@ -1783,10 +1796,16 @@ namespace dpct
|
|
|
auto data_a = get_memory<const Ta>(a);
|
|
auto data_a = get_memory<const Ta>(a);
|
|
|
auto data_b = get_memory<const Tb>(b);
|
|
auto data_b = get_memory<const Tb>(b);
|
|
|
auto data_c = get_memory<Tc>(c);
|
|
auto data_c = get_memory<Tc>(c);
|
|
|
|
|
+#ifdef GGML_SYCL_NVIDIA
|
|
|
oneapi::mkl::blas::column_major::gemm_batch(
|
|
oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
|
|
|
|
- stride_a, data_b, ldb, stride_b, beta_value,
|
|
|
|
|
- data_c, ldc, stride_c, batch_size);
|
|
|
|
|
|
|
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
|
|
|
|
|
+ alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
|
|
|
|
+ batch_size);
|
|
|
|
|
+#else
|
|
|
|
|
+ oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
|
|
|
|
+ stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
|
|
|
|
+ stride_c, batch_size);
|
|
|
|
|
+#endif
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
} // namespace detail
|
|
} // namespace detail
|