|
@@ -48,6 +48,7 @@
|
|
|
#include "ggml-sycl/set.hpp"
|
|
#include "ggml-sycl/set.hpp"
|
|
|
#include "ggml-sycl/sycl_hw.hpp"
|
|
#include "ggml-sycl/sycl_hw.hpp"
|
|
|
#include "ggml-sycl/getrows.hpp"
|
|
#include "ggml-sycl/getrows.hpp"
|
|
|
|
|
+#include "ggml-sycl/repeat_back.hpp"
|
|
|
#include "ggml-sycl/quantize.hpp"
|
|
#include "ggml-sycl/quantize.hpp"
|
|
|
#include "ggml.h"
|
|
#include "ggml.h"
|
|
|
|
|
|
|
@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
|
|
|
std::exit(1);
|
|
std::exit(1);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
|
|
|
+ ggml_sycl_op_repeat_back(ctx, dst);
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
|
@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
case GGML_OP_REPEAT:
|
|
case GGML_OP_REPEAT:
|
|
|
ggml_sycl_repeat(ctx, dst);
|
|
ggml_sycl_repeat(ctx, dst);
|
|
|
break;
|
|
break;
|
|
|
|
|
+ case GGML_OP_REPEAT_BACK:
|
|
|
|
|
+ ggml_sycl_repeat_back(ctx, dst);
|
|
|
|
|
+ break;
|
|
|
case GGML_OP_GET_ROWS:
|
|
case GGML_OP_GET_ROWS:
|
|
|
ggml_sycl_get_rows(ctx, dst);
|
|
ggml_sycl_get_rows(ctx, dst);
|
|
|
break;
|
|
break;
|
|
@@ -4516,6 +4524,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
ggml_type src0_type = op->src[0]->type;
|
|
ggml_type src0_type = op->src[0]->type;
|
|
|
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
|
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
|
|
}
|
|
}
|
|
|
|
|
+ case GGML_OP_REPEAT_BACK:
|
|
|
|
|
+ {
|
|
|
|
|
+ ggml_type src0_type = op->src[0]->type;
|
|
|
|
|
+ return src0_type == GGML_TYPE_F32;
|
|
|
|
|
+ }
|
|
|
case GGML_OP_DUP:
|
|
case GGML_OP_DUP:
|
|
|
case GGML_OP_ARGMAX:
|
|
case GGML_OP_ARGMAX:
|
|
|
case GGML_OP_NONE:
|
|
case GGML_OP_NONE:
|