@@ -0,0 +1,135 @@
+#include "roll.hpp"
+#include "common.hpp"
+#include <cstdio> // std::fprintf in the error path below
+
+using namespace sycl;
+
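+// Computes (i + shift) mod n while avoiding a per-element integer modulo.
+// Callers must guarantee 0 <= i < n and 0 <= shift < n, so the sum is
+// always below 2n and one conditional subtraction is enough.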
+static inline int wrap_add(int i, int shift, int n) {
+    int s = i + shift;
+    return (s >= n) ? (s - n) : s;
+}
+
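+// One work-item per destination element. A SYCL nd-range has at most three
+// dimensions, so the two innermost tensor dims (i1, i0) are fused into the
+// last range dimension and split apart again inside the kernel.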
+static void kernel_roll_fused_i0_i1(
+    queue &q,
+    const float *src_d,
+    float *dst_d,
+    int ne0, int ne1, int ne2, int ne3,
+    int sh0, int sh1, int sh2, int sh3)
+{
+    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
+
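+    // Flat strides for a contiguous row-major layout (ne0 is the fastest
+    // dim). All index math is done in int, so the element count must fit.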
+    const int stride1 = ne0;
+    const int stride2 = ne0 * ne1;
+    const int stride3 = ne0 * ne1 * ne2;
+
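+    // Invert each shift once on the host: rolling forward by sh means
+    // dst[i] = src[(i - sh) mod n], so precomputing (n - sh) mod n lets the
+    // kernel do nothing but a wrapped add.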
+    const int shNe0 = (ne0 - sh0) % ne0;
+    const int shNe1 = (ne1 - sh1) % ne1;
+    const int shNe2 = (ne2 - sh2) % ne2;
+    const int shNe3 = (ne3 - sh3) % ne3;
+
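+    // Global range mapping: [0] -> i3, [1] -> i2, [2] -> fused i1*ne0 + i0.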
+    const size_t g0 = (size_t) ne3;
+    const size_t g1 = (size_t) ne2;
+    const size_t g2 = (size_t) (ne1 * ne0);
+
+    const range<3> global{ g0, g1, g2 };
+
+    q.submit([&](handler &h) {
+        h.parallel_for(global, [=](id<3> idx) {
+            const int i3 = (int) idx[0];
+            const int i2 = (int) idx[1];
+
+            const int fused = (int) idx[2];
+            const int i1 = fused / ne0;
+            const int i0 = fused - i1 * ne0; // fused % ne0
+
+            const int idx_dst = i0
+                + i1 * stride1
+                + i2 * stride2
+                + i3 * stride3;
+
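+            // Source element: wrap each index by the per-axis inverse shift.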
+            const int s0 = wrap_add(i0, shNe0, ne0);
+            const int s1 = wrap_add(i1, shNe1, ne1);
+            const int s2 = wrap_add(i2, shNe2, ne2);
+            const int s3 = wrap_add(i3, shNe3, ne3);
+
+            const int idx_src = s0
+                + s1 * stride1
+                + s2 * stride2
+                + s3 * stride3;
+
+            dst_d[idx_dst] = src_d[idx_src];
+        });
+    });
+}
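+
+// Host-side entry point: validates types, normalizes the shifts and launches
+// the kernel (expected to be reached from the backend's GGML_OP_ROLL dispatch).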
+void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const ggml_tensor *src = dst->src[0];
+    GGML_ASSERT(src && src->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src) && ggml_is_contiguous(dst)); // flat indexing assumes this
+
+    const int ne0 = (int) dst->ne[0];
+    const int ne1 = (int) dst->ne[1];
+    const int ne2 = (int) dst->ne[2];
+    const int ne3 = (int) dst->ne[3];
+
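+    // The four per-axis shifts arrive packed as int32 values in op_params.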
+    const int32_t *params = (const int32_t *) dst->op_params;
+    int shift0 = params[0];
+    int shift1 = params[1];
+    int shift2 = params[2];
+    int shift3 = params[3];
+
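+    // Fast path: all-zero shifts reduce ROLL to a plain device-to-device copy.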
+    if ((shift0 | shift1 | shift2 | shift3) == 0) {
+        const size_t nb = ggml_nbytes(src);
+        queue *q = ctx.stream();
+        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
+        return;
+    }
+
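+    // Bring each shift into [0, n): C++ '%' keeps the dividend's sign, so a
+    // negative remainder needs one corrective addition.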
+    auto norm = [](int sh, int n) -> int {
+        if (n <= 0) return 0;
+        sh %= n;
+        if (sh < 0) sh += n;
+        return sh;
+    };
+    shift0 = norm(shift0, ne0);
+    shift1 = norm(shift1, ne1);
+    shift2 = norm(shift2, ne2);
+    shift3 = norm(shift3, ne3);
+
+    try {
+        queue *q = ctx.stream();
+
+        const float *src_d = (const float *) src->data;
+        float *dst_d = (float *) dst->data;
+        GGML_ASSERT(src_d && dst_d);
+
+        kernel_roll_fused_i0_i1(
+            *q, src_d, dst_d,
+            ne0, ne1, ne2, ne3,
+            shift0, shift1, shift2, shift3
+        );
+    } catch (const std::exception &e) {
+        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
+        throw;
+    }
+}