|
|
@@ -1,4 +1,5 @@
|
|
|
#include "set_rows.hpp"
|
|
|
+#include "cpy.hpp"
|
|
|
|
|
|
namespace utils {
|
|
|
template<typename T>
|
|
|
@@ -15,6 +16,68 @@ convert (const char* src, char* dst) {
|
|
|
*reinterpret_cast<TOut*>(dst) = dst_val;
|
|
|
}
|
|
|
|
|
|
+template <typename blockType, int qk, cpy_kernel_t cpyblck>
|
|
|
+static void set_rows_sycl_q(const char * __restrict__ src0_d,
|
|
|
+ const int64_t * __restrict__ src1_d,
|
|
|
+ blockType * __restrict__ dst_d,
|
|
|
+ // tensor dimensions src0 and src1
|
|
|
+ const int64_t ne00,
|
|
|
+ const int64_t ne01,
|
|
|
+ const int64_t ne02,
|
|
|
+ const int64_t ne03,
|
|
|
+ const int64_t ne10,
|
|
|
+ const int64_t ne11,
|
|
|
+ const int64_t ne12,
|
|
|
+ const int64_t ne13,
|
|
|
+ // strides for src0
|
|
|
+ const size_t nb00,
|
|
|
+ const size_t nb01,
|
|
|
+ const size_t nb02,
|
|
|
+ const size_t nb03,
|
|
|
+ // strides for src1
|
|
|
+ const size_t nb10,
|
|
|
+ const size_t nb11,
|
|
|
+ const size_t nb12,
|
|
|
+ const size_t nb13,
|
|
|
+ // strides for dst
|
|
|
+ const size_t nb1,
|
|
|
+ const size_t nb2,
|
|
|
+ const size_t nb3,
|
|
|
+ queue_ptr stream) {
|
|
|
+ const int64_t total_blocks = (ne00 * ne01 * ne02 * ne03) / qk;
|
|
|
+ constexpr int block_size = 256;
|
|
|
+ const int64_t grid_size = ceil_div(total_blocks, block_size);
|
|
|
+
|
|
|
+ sycl_parallel_for(stream, sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) {
|
|
|
+ const int64_t i = item_ct1.get_global_linear_id();
|
|
|
+ if (i >= total_blocks) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ const int64_t i_base = i * qk;
|
|
|
+ const int64_t i03 = i_base / (ne00 * ne01 * ne02);
|
|
|
+ const int64_t rem1 = i_base - i03 * (ne00 * ne01 * ne02);
|
|
|
+ const int64_t i02 = rem1 / (ne00 * ne01);
|
|
|
+ const int64_t rem2 = rem1 - i02 * ne00 * ne01;
|
|
|
+ const int64_t i01 = rem2 / ne00;
|
|
|
+ const int64_t i00 = rem2 - i01 * ne00;
|
|
|
+ const int64_t i12 = i03 % ne12;
|
|
|
+ const int64_t i11 = i02 % ne11;
|
|
|
+ const int64_t i10 = i01;
|
|
|
+ const size_t src_offset = calculate_offset<3>({ nb01, nb02, nb03 }, { i01, i02, i03 });
|
|
|
+ const char * src_block = src0_d + src_offset + i00 * sizeof(float);
|
|
|
+ const size_t src1_offset = calculate_offset<3>({ nb10, nb11, nb12 }, { i10, i11, i12 });
|
|
|
+ const int64_t dst_row = src1_d[src1_offset / sizeof(int64_t)];
|
|
|
+ const size_t dst_offset =
|
|
|
+ calculate_offset<3>({ nb1, nb2, nb3 }, { dst_row, i02, i03 }) + (i00 / qk) * sizeof(blockType);
|
|
|
+ char * dst_block = reinterpret_cast<char *>(reinterpret_cast<char *>(dst_d) + dst_offset);
|
|
|
+ cpyblck(src_block, dst_block);
|
|
|
+ });
|
|
|
+ GGML_UNUSED(ne10);
|
|
|
+ GGML_UNUSED(ne13);
|
|
|
+ GGML_UNUSED(nb00);
|
|
|
+ GGML_UNUSED(nb13);
|
|
|
+}
|
|
|
+
|
|
|
template<typename TIn, typename TOut>
|
|
|
static void k_set_rows(
|
|
|
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
|
|
|
@@ -124,6 +187,37 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
stream
|
|
|
);
|
|
|
break;
|
|
|
+ case GGML_TYPE_BF16:
|
|
|
+ set_rows_sycl<float, sycl::ext::oneapi::bfloat16>(
|
|
|
+ (const char *)src0->data, src1_dd, (char *)dst->data,
|
|
|
+ ne00, ne01, ne02, ne03,
|
|
|
+ ne11, ne12,
|
|
|
+ nb01, nb02, nb03,
|
|
|
+ nb10, nb11, nb12,
|
|
|
+ nb1, nb2, nb3,
|
|
|
+ sizeof(float), sizeof(sycl::ext::oneapi::bfloat16),
|
|
|
+ stream
|
|
|
+ );
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q8_0:
|
|
|
+ set_rows_sycl_q<block_q8_0, QK8_0, cpy_blck_f32_q8_0>((const char *)src0->data, src1_dd, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q5_1:
|
|
|
+ set_rows_sycl_q<block_q5_1, QK5_1, cpy_blck_f32_q5_1>((const char *)src0->data, src1_dd, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q5_0:
|
|
|
+ set_rows_sycl_q<block_q5_0, QK5_0, cpy_blck_f32_q5_0>((const char *)src0->data, src1_dd, (block_q5_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q4_1:
|
|
|
+ set_rows_sycl_q<block_q4_1, QK4_1, cpy_blck_f32_q4_1>((const char *)src0->data, src1_dd, (block_q4_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q4_0:
|
|
|
+ set_rows_sycl_q<block_q4_0, QK4_0, cpy_blck_f32_q4_0>((const char *)src0->data, src1_dd, (block_q4_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_IQ4_NL:
|
|
|
+ set_rows_sycl_q<block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>((const char *)src0->data, src1_dd, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
+
|
|
|
default:
|
|
|
GGML_ABORT("Unsupported tensor type!");
|
|
|
break;
|