|
|
@@ -16,9 +16,18 @@
|
|
|
#include <sycl/sycl.hpp>
|
|
|
#include <sycl/half_type.hpp>
|
|
|
#include <syclcompat/math.hpp>
|
|
|
-#include <oneapi/mkl.hpp>
|
|
|
#include <map>
|
|
|
|
|
|
+#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
|
|
+#include <oneapi/mkl.hpp>
|
|
|
+// Allow using the same namespace for Intel oneMKL and oneMath
|
|
|
+namespace oneapi {
|
|
|
+ namespace math = mkl;
|
|
|
+}
|
|
|
+#else
|
|
|
+#include <oneapi/math.hpp>
|
|
|
+#endif
|
|
|
+
|
|
|
#include "ggml.h"
|
|
|
|
|
|
#if defined(__linux__)
|
|
|
@@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
|
}
|
|
|
|
|
|
template <typename Ts> struct matrix_info_t {
|
|
|
- oneapi::mkl::transpose transpose_info[2];
|
|
|
+ oneapi::math::transpose transpose_info[2];
|
|
|
Ts value_info[2];
|
|
|
std::int64_t size_info[3];
|
|
|
std::int64_t ld_info[3];
|
|
|
std::int64_t groupsize_info;
|
|
|
};
|
|
|
|
|
|
+inline auto get_onemath_backend(sycl::queue& queue)
|
|
|
+#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
|
+ -> sycl::queue&
|
|
|
+#endif
|
|
|
+{
|
|
|
+// If the backend is known at compile-time, use oneMath backend_selector to use
|
|
|
+// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
|
|
|
+// fall back to runtime dispatching.
|
|
|
+#if defined(GGML_SYCL_NVIDIA)
|
|
|
+ return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
|
|
|
+#elif defined(GGML_SYCL_AMD)
|
|
|
+ return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
|
|
|
+#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
|
+ return queue;
|
|
|
+#else
|
|
|
+ static_assert(false, "Unsupported backend");
|
|
|
+#endif
|
|
|
+}
|
|
|
+
|
|
|
namespace dpct
|
|
|
{
|
|
|
typedef sycl::queue *queue_ptr;
|
|
|
@@ -1686,26 +1714,18 @@ namespace dpct
|
|
|
|
|
|
namespace detail
|
|
|
{
|
|
|
- template <class Ta, class Tb, class Tc, class Ts>
|
|
|
- inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
|
|
- oneapi::mkl::transpose b_trans, int m, int n, int k,
|
|
|
- const void *alpha, const void *a, int lda, const void *b,
|
|
|
- int ldb, const void *beta, void *c, int ldc)
|
|
|
- {
|
|
|
- Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
|
- Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
|
|
- auto data_a = get_memory<const Ta>(a);
|
|
|
- auto data_b = get_memory<const Tb>(b);
|
|
|
- auto data_c = get_memory<Tc>(c);
|
|
|
-#ifdef GGML_SYCL_NVIDIA
|
|
|
- oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
|
|
- a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
|
- beta_value, data_c, ldc);
|
|
|
-#else
|
|
|
- oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
|
- beta_value, data_c, ldc);
|
|
|
-#endif
|
|
|
- }
|
|
|
+ template <class Ta, class Tb, class Tc, class Ts>
|
|
|
+ inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
|
|
+ int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
|
|
|
+ const void * beta, void * c, int ldc) {
|
|
|
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
|
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
|
|
+ auto data_a = get_memory<const Ta>(a);
|
|
|
+ auto data_b = get_memory<const Tb>(b);
|
|
|
+ auto data_c = get_memory<Tc>(c);
|
|
|
+ oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
|
|
|
+ lda, data_b, ldb, beta_value, data_c, ldc);
|
|
|
+ }
|
|
|
|
|
|
template <typename VecT, class BinaryOperation, class = void>
|
|
|
class vectorized_binary
|
|
|
@@ -1735,7 +1755,7 @@ namespace dpct
|
|
|
};
|
|
|
|
|
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
|
- inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
|
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
|
|
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
|
|
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
|
|
matrix_info_t<float> * matrix_info) {
|
|
|
@@ -1754,48 +1774,28 @@ namespace dpct
|
|
|
matrix_info->ld_info[2] = ldc;
|
|
|
matrix_info->groupsize_info = batch_size;
|
|
|
|
|
|
-#ifdef GGML_SYCL_NVIDIA
|
|
|
- sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
|
|
- matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
|
|
- matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
|
|
- reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
|
- matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
|
|
- reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
-#else
|
|
|
- sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
- q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
|
|
- matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
|
|
- reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
|
- matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
|
|
- reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
-#endif
|
|
|
+ sycl::event e = oneapi::math::blas::column_major::gemm_batch(
|
|
|
+ get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
|
|
+ matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
|
|
|
+ reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
|
|
+ reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
|
+ reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
|
|
|
+ matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
}
|
|
|
|
|
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
|
- inline void
|
|
|
- gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
|
|
- oneapi::mkl::transpose b_trans, int m, int n,
|
|
|
- int k, const void *alpha, const void *a, int lda,
|
|
|
- long long int stride_a, const void *b, int ldb,
|
|
|
- long long int stride_b, const void *beta, void *c,
|
|
|
- int ldc, long long int stride_c, int batch_size)
|
|
|
- {
|
|
|
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
|
|
+ int m, int n, int k, const void * alpha, const void * a, int lda,
|
|
|
+ long long int stride_a, const void * b, int ldb, long long int stride_b,
|
|
|
+ const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
|
|
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
|
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
|
|
auto data_a = get_memory<const Ta>(a);
|
|
|
auto data_b = get_memory<const Tb>(b);
|
|
|
auto data_c = get_memory<Tc>(c);
|
|
|
-#ifdef GGML_SYCL_NVIDIA
|
|
|
- oneapi::mkl::blas::column_major::gemm_batch(
|
|
|
- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
|
|
|
- alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
|
|
- batch_size);
|
|
|
-#else
|
|
|
- oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
|
|
- stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
|
|
- stride_c, batch_size);
|
|
|
-#endif
|
|
|
+ oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
|
|
|
+ data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
|
|
+ data_c, ldc, stride_c, batch_size);
|
|
|
}
|
|
|
|
|
|
} // namespace detail
|
|
|
@@ -2259,13 +2259,10 @@ namespace dpct
|
|
|
sycl::range<3>(x, y, 1), direction);
|
|
|
}
|
|
|
|
|
|
- inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
|
|
- oneapi::mkl::transpose b_trans, int m, int n, int k,
|
|
|
- const void *alpha, const void *a, library_data_t a_type,
|
|
|
- int lda, const void *b, library_data_t b_type, int ldb,
|
|
|
- const void *beta, void *c, library_data_t c_type, int ldc,
|
|
|
- library_data_t scaling_type)
|
|
|
- {
|
|
|
+ inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
|
|
|
+ int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
|
|
|
+ library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
|
+ library_data_t scaling_type) {
|
|
|
if (scaling_type == library_data_t::real_float &&
|
|
|
c_type == library_data_t::complex_float)
|
|
|
{
|
|
|
@@ -2329,9 +2326,8 @@ namespace dpct
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
|
|
- float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
|
|
|
- ldb, beta, c, ldc);
|
|
|
+ detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
@@ -2369,8 +2365,7 @@ namespace dpct
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
|
|
- oneapi::mkl::bfloat16, float>(
|
|
|
+ detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
|
break;
|
|
|
}
|
|
|
@@ -2390,7 +2385,7 @@ namespace dpct
|
|
|
default:
|
|
|
throw std::runtime_error("the combination of data type is unsupported");
|
|
|
}
|
|
|
- } // gemm()
|
|
|
+ } // gemm()
|
|
|
|
|
|
/// Computes a batch of matrix-matrix product with general matrices.
|
|
|
/// \param [in] q The queue where the routine should be executed.
|
|
|
@@ -2412,7 +2407,7 @@ namespace dpct
|
|
|
/// \param [in] ldc Leading dimension of C.
|
|
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
|
- inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
|
+ inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
|
|
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
|
|
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
|
|
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
|
@@ -2450,7 +2445,7 @@ namespace dpct
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
|
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
@@ -2458,7 +2453,7 @@ namespace dpct
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
|
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
|
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
break;
|
|
|
}
|
|
|
@@ -2534,15 +2529,11 @@ namespace dpct
|
|
|
/// \param [in] stride_c Stride between the different C matrices.
|
|
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
|
- inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
|
|
- oneapi::mkl::transpose b_trans, int m, int n, int k,
|
|
|
- const void *alpha, const void *a, library_data_t a_type,
|
|
|
- int lda, long long int stride_a, const void *b,
|
|
|
- library_data_t b_type, int ldb, long long int stride_b,
|
|
|
- const void *beta, void *c, library_data_t c_type,
|
|
|
- int ldc, long long int stride_c, int batch_size,
|
|
|
- library_data_t scaling_type)
|
|
|
- {
|
|
|
+ inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
|
|
+ int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
|
|
|
+ long long int stride_a, const void * b, library_data_t b_type, int ldb,
|
|
|
+ long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
|
+ long long int stride_c, int batch_size, library_data_t scaling_type) {
|
|
|
if (scaling_type == library_data_t::real_float &&
|
|
|
c_type == library_data_t::complex_float)
|
|
|
{
|
|
|
@@ -2611,20 +2602,18 @@ namespace dpct
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
|
|
- oneapi::mkl::bfloat16, float>(
|
|
|
- q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
|
|
- beta, c, ldc, stride_c, batch_size);
|
|
|
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
|
+ batch_size);
|
|
|
break;
|
|
|
}
|
|
|
case detail::get_type_combination_id(
|
|
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
|
library_data_t::real_float, library_data_t::real_float):
|
|
|
{
|
|
|
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
|
|
- float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
|
|
|
- stride_a, b, ldb, stride_b, beta, c, ldc,
|
|
|
- stride_c, batch_size);
|
|
|
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
|
|
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
|
+ batch_size);
|
|
|
break;
|
|
|
}
|
|
|
#endif
|