|
|
@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
|
return device_type.str();
|
|
|
}
|
|
|
|
|
|
+template <typename Ts> struct matrix_info_t { // Argument bundle whose members are passed by address to oneapi::mkl::blas::column_major::gemm_batch in gemm_batch_impl; Ts is the type of the scaling factors.
|
|
|
+ oneapi::mkl::transpose transpose_info[2]; // [0] = a_trans, [1] = b_trans
|
|
|
+ Ts value_info[2]; // [0] = alpha; [1] is passed in the beta position of gemm_batch
|
|
|
+ std::int64_t size_info[3]; // passed to gemm_batch in the m, n, k argument positions (in that order)
|
|
|
+ std::int64_t ld_info[3]; // passed to gemm_batch in the lda, ldb, ldc argument positions (in that order)
|
|
|
+ std::int64_t groupsize_info; // passed as the group-size array, with a group count of 1
|
|
|
+};
|
|
|
+
|
|
|
namespace dpct
|
|
|
{
|
|
|
typedef sycl::queue *queue_ptr;
|
|
|
@@ -1727,26 +1735,13 @@ namespace dpct
|
|
|
};
|
|
|
|
|
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
|
- inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
|
|
- oneapi::mkl::transpose b_trans, int m, int n, int k,
|
|
|
- const void *alpha, const void **a, int lda,
|
|
|
- const void **b, int ldb, const void *beta, void **c,
|
|
|
- int ldc, int batch_size)
|
|
|
- {
|
|
|
- struct matrix_info_t
|
|
|
- {
|
|
|
- oneapi::mkl::transpose transpose_info[2];
|
|
|
- Ts value_info[2];
|
|
|
- std::int64_t size_info[3];
|
|
|
- std::int64_t ld_info[3];
|
|
|
- std::int64_t groupsize_info;
|
|
|
- };
|
|
|
-
|
|
|
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
|
+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
|
|
+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
|
|
+ matrix_info_t<float> * matrix_info) {
|
|
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
|
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
|
|
|
|
|
- matrix_info_t *matrix_info =
|
|
|
- (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
|
|
|
matrix_info->transpose_info[0] = a_trans;
|
|
|
matrix_info->transpose_info[1] = b_trans;
|
|
|
matrix_info->value_info[0] = alpha_value;
|
|
|
@@ -1763,23 +1758,18 @@ namespace dpct
|
|
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
|
|
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
|
|
- matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
|
|
- matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
|
- matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
|
|
- &(matrix_info->groupsize_info));
|
|
|
+ matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
|
|
+ reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
|
+ matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
|
|
+ reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
#else
|
|
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
|
|
- matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
|
|
+ matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
|
|
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
|
- matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
|
|
- matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
+ matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
|
|
+ reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
#endif
|
|
|
-
|
|
|
- q.submit([&](sycl::handler &cgh)
|
|
|
- {
|
|
|
- cgh.depends_on(e);
|
|
|
- cgh.host_task([=] { std::free(matrix_info); }); });
|
|
|
}
|
|
|
|
|
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
|
@@ -2422,25 +2412,11 @@ namespace dpct
|
|
|
/// \param [in] ldc Leading dimension of C.
|
|
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
|
- inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
|
|
- oneapi::mkl::transpose b_trans, int m, int n, int k,
|
|
|
- const void *alpha, const void *a[],
|
|
|
- library_data_t a_type, int lda, const void *b[],
|
|
|
- library_data_t b_type, int ldb, const void *beta,
|
|
|
- void *c[], library_data_t c_type, int ldc,
|
|
|
- int batch_size, library_data_t scaling_type)
|
|
|
- {
|
|
|
- if (scaling_type == library_data_t::real_float &&
|
|
|
- c_type == library_data_t::complex_float)
|
|
|
- {
|
|
|
- scaling_type = library_data_t::complex_float;
|
|
|
- }
|
|
|
- else if (scaling_type == library_data_t::real_double &&
|
|
|
- c_type == library_data_t::complex_double)
|
|
|
- {
|
|
|
- scaling_type = library_data_t::complex_double;
|
|
|
- }
|
|
|
-
|
|
|
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
|
+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
|
|
+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
|
|
+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
|
+ matrix_info_t<float> * matrix_info) {
|
|
|
std::uint64_t key =
|
|
|
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
|
|
switch (key)
|
|
|
@@ -2449,48 +2425,24 @@ namespace dpct
|
|
|
library_data_t::real_float, library_data_t::real_float,
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<float, float, float, float>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
+ detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
|
|
+ beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
library_data_t::real_double, library_data_t::real_double,
|
|
|
library_data_t::real_double, library_data_t::real_double):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<double, double, double, double>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
- break;
|
|
|
- }
|
|
|
- case detail::get_type_combination_id(
|
|
|
- library_data_t::complex_float, library_data_t::complex_float,
|
|
|
- library_data_t::complex_float, library_data_t::complex_float):
|
|
|
- {
|
|
|
- detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
|
|
|
- std::complex<float>, std::complex<float>>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
- break;
|
|
|
- }
|
|
|
- case detail::get_type_combination_id(
|
|
|
- library_data_t::complex_double, library_data_t::complex_double,
|
|
|
- library_data_t::complex_double, library_data_t::complex_double):
|
|
|
- {
|
|
|
- detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
|
|
|
- std::complex<double>, std::complex<double>>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
+ detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
|
|
+ beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
library_data_t::real_half, library_data_t::real_half,
|
|
|
library_data_t::real_half, library_data_t::real_half):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
|
|
|
- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
|
|
|
- a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
#ifdef __INTEL_MKL__
|
|
|
@@ -2498,19 +2450,16 @@ namespace dpct
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
|
|
- oneapi::mkl::bfloat16, float>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
|
|
- float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
|
|
|
- b, ldb, beta, c, ldc, batch_size);
|
|
|
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
#endif
|
|
|
@@ -2522,10 +2471,9 @@ namespace dpct
|
|
|
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
|
|
float beta_float =
|
|
|
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
|
|
- detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
|
|
- float>(q, a_trans, b_trans, m, n, k, &alpha_float,
|
|
|
- a, lda, b, ldb, &beta_float, c, ldc,
|
|
|
- batch_size);
|
|
|
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
|
|
|
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
|
|
|
+ matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
@@ -2533,8 +2481,7 @@ namespace dpct
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
@@ -2542,8 +2489,7 @@ namespace dpct
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
|
- batch_size);
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
@@ -2557,8 +2503,7 @@ namespace dpct
|
|
|
sycl::half alpha_half(alpha_value);
|
|
|
sycl::half beta_half(beta_value);
|
|
|
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
|
|
- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
|
|
|
- batch_size);
|
|
|
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
default:
|