|
|
@@ -2,6 +2,13 @@
|
|
|
#include "dequantize.hpp"
|
|
|
#include "presets.hpp"
|
|
|
|
|
|
+#if defined(__INTEL_LLVM_COMPILER)
|
|
|
+ #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
|
|
|
+ #include <sycl/ext/oneapi/bfloat16.hpp>
|
|
|
+ #define GGML_SYCL_HAS_BF16
|
|
|
+ #endif
|
|
|
+#endif
|
|
|
+
|
|
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
|
|
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
|
|
const sycl::nd_item<3> &item_ct1) {
|
|
|
@@ -566,6 +573,10 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
|
|
return dequantize_row_iq4_nl_sycl;
|
|
|
case GGML_TYPE_F32:
|
|
|
return convert_unary_sycl<float>;
|
|
|
+#ifdef GGML_SYCL_HAS_BF16
|
|
|
+ case GGML_TYPE_BF16:
|
|
|
+ return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
|
|
|
+#endif
|
|
|
default:
|
|
|
return nullptr;
|
|
|
}
|
|
|
@@ -627,6 +638,10 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
|
|
return dequantize_row_iq4_nl_sycl;
|
|
|
case GGML_TYPE_F16:
|
|
|
return convert_unary_sycl<sycl::half>;
|
|
|
+#ifdef GGML_SYCL_HAS_BF16
|
|
|
+ case GGML_TYPE_BF16:
|
|
|
+ return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
|
|
|
+#endif
|
|
|
default:
|
|
|
return nullptr;
|
|
|
}
|
|
|
@@ -636,6 +651,10 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
|
|
|
switch (type) {
|
|
|
case GGML_TYPE_F32:
|
|
|
return convert_unary_nc_sycl<float>;
|
|
|
+#ifdef GGML_SYCL_HAS_BF16
|
|
|
+ case GGML_TYPE_BF16:
|
|
|
+ return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
|
|
|
+#endif
|
|
|
default:
|
|
|
return nullptr;
|
|
|
}
|