@@ -2694,35 +2694,31 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12,
-                                   size_t nb13, size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }
 
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                           const ggml_tensor *src0,
-                                           const ggml_tensor *src1,
-                                           ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
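The pointer tables built by k_compute_batched_ptrs feed the array-of-pointers gemm_batch call further down: entry (i12, i13) of each table addresses one GEMM in the batch, and broadcasting maps several src1 batches onto one src0 batch via r2 = ne12 / ne02 and r3 = ne13 / ne03. A host-side reference sketch of the same layout (hypothetical helper, not part of the patch):

    #include <cstdint>
    #include <vector>

    struct batched_ptrs {
        std::vector<const void *> src;  // 2 * ne23 entries: [0, ne23) -> src0, [ne23, 2*ne23) -> src1
        std::vector<void *>       dst;  // ne23 entries
    };

    static batched_ptrs compute_batched_ptrs_host(const uint8_t * src0, const uint8_t * src1, uint8_t * dst,
                                                  int64_t ne12, int64_t ne13, int64_t r2, int64_t r3,
                                                  size_t nb02, size_t nb03, size_t nb12, size_t nb13,
                                                  size_t nbd2, size_t nbd3) {
        const int64_t ne23 = ne12 * ne13;
        batched_ptrs p = { std::vector<const void *>(2 * ne23), std::vector<void *>(ne23) };
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                const int64_t i02 = i12 / r2;  // several src1 batches can reuse one src0 batch
                const int64_t i03 = i13 / r3;
                p.src[0 * ne23 + i12 + i13 * ne12] = src0 + i02 * nb02 + i03 * nb03;
                p.src[1 * ne23 + i12 + i13 * ne12] = src1 + i12 * nb12 + i13 * nb13;
                p.dst[0 * ne23 + i12 + i13 * ne12] = dst  + i12 * nbd2 + i13 * nbd3;
            }
        }
        return p;
    }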
@@ -2730,102 +2726,100 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();;
+    queue_ptr queue = ctx.stream();
 
-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
-    // convert src1 to fp16
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float * dst_ddf = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t s11 = nb11 / type_size_src1;
+    int64_t s12 = nb12 / type_size_src1;
+    int64_t s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11 * s11;
+        s13 = ne12 * s12;
     }
 
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
-
-    char * dst_t;
+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
+    char * dst_t = reinterpret_cast<char *>(dst_ddf);
 
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
 
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+    const float beta_f32  = 0.0f;
 
     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;
 
-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
 
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-            (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
-            cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                    oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                    mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
     } else {
-        const int ne23 = ne12*ne13;
+        const int ne23 = ne12 * ne13;
 
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+        ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
 
         sycl::range<3> block_dims(1, ne12, ne13);
-        /*
-        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(main_stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            main_stream->submit([&](sycl::handler &cgh) {
-                const void **ptrs_src_get = ptrs_src.get();
-                void **ptrs_dst_get = ptrs_dst.get();
-                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                                 [=](sycl::nd_item<3> item_ct1) {
-                                     k_compute_batched_ptrs(
-                                         src0_as_f16, src1_f16,
-                                         dst_t, ptrs_src_get,
-                                         ptrs_dst_get, ne12, ne13, ne23,
-                                         nb02, nb03, nb12_scaled, nb13_scaled,
-                                         nbd2, nbd3, r2, r3, item_ct1);
-                                 });
+        queue->submit([&](sycl::handler & cgh) {
+            const void ** ptrs_src_get = ptrs_src.get();
+            void ** ptrs_dst_get = ptrs_dst.get();
+            size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+            size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                       nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
-        }
+        });
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
+            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
     }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
 
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
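The key change in this hunk is that src1 is addressed through element strides s11/s12/s13 derived from its byte strides, so a non-contiguous src1 can be passed to gemm_batch directly; only the fp16 conversion path still repacks it, after which the strides collapse to the packed layout of the temporary buffer. A minimal sketch of that stride bookkeeping (standalone helpers for illustration, not the patch's code; the names s11/s12/s13 match the hunk):

    #include <cstddef>
    #include <cstdint>

    struct src1_strides { int64_t s11, s12, s13; };

    // element strides of src1 in place; requires nb10 == type_size (asserted in the patch)
    static src1_strides element_strides(size_t nb11, size_t nb12, size_t nb13, size_t type_size) {
        return { int64_t(nb11 / type_size), int64_t(nb12 / type_size), int64_t(nb13 / type_size) };
    }

    // strides after to_fp16_nc_sycl repacked src1 into a dense ne10 x ne11 x ne12 x ne13 buffer
    static src1_strides packed_strides(int64_t ne10, int64_t ne11, int64_t ne12) {
        const int64_t s11 = ne10;
        const int64_t s12 = ne11 * s11;
        const int64_t s13 = ne12 * s12;
        return { s11, s12, s13 };
    }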
@@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
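The extra ggml_is_contiguous(src1) condition narrows the mul_mat_vec_nc fast path to contiguous src1; a non-contiguous src1 now falls through to ggml_sycl_mul_mat_batched_sycl, which handles it via the strided path above. The same condition, pulled out as a standalone predicate for readability (hypothetical helper; the real code tests this inline):

    #include "ggml.h"

    static bool use_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, bool split) {
        return !split && src0->type == GGML_TYPE_F16 &&
               !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
               !ggml_is_transposed(src1) && src1->ne[1] == 1;
    }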
@@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
-                if (!ggml_is_contiguous(b)) {
-                    return false;
-                }
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
                     a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
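Dropping the contiguity check on b widens GGML_OP_MUL_MAT support: a permuted src1 view is now reported as supported by the SYCL backend instead of being scheduled elsewhere. A hypothetical repro of such a case, using only the standard ggml API (shapes are arbitrary but chosen to satisfy ggml_can_mul_mat):

    #include "ggml.h"

    static struct ggml_tensor * mul_mat_permuted_b(struct ggml_context * ctx) {
        struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 32, 4, 2);
        struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 4, 32, 2);
        b = ggml_permute(ctx, b, 0, 2, 1, 3);  // swap dims 1 and 2 -> ggml_is_contiguous(b) == false
        return ggml_mul_mat(ctx, a, b);        // previously rejected by supports_op, now accepted
    }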