|
@@ -50,6 +50,7 @@
|
|
|
#include "ggml-cuda/upscale.cuh"
|
|
#include "ggml-cuda/upscale.cuh"
|
|
|
#include "ggml-cuda/wkv.cuh"
|
|
#include "ggml-cuda/wkv.cuh"
|
|
|
#include "ggml-cuda/gla.cuh"
|
|
#include "ggml-cuda/gla.cuh"
|
|
|
|
|
+#include "ggml-cuda/set.cuh"
|
|
|
#include "ggml-cuda/set-rows.cuh"
|
|
#include "ggml-cuda/set-rows.cuh"
|
|
|
#include "ggml-cuda/pad_reflect_1d.cuh"
|
|
#include "ggml-cuda/pad_reflect_1d.cuh"
|
|
|
#include "ggml.h"
|
|
#include "ggml.h"
|
|
@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
|
case GGML_OP_SET_ROWS:
|
|
case GGML_OP_SET_ROWS:
|
|
|
ggml_cuda_op_set_rows(ctx, dst);
|
|
ggml_cuda_op_set_rows(ctx, dst);
|
|
|
break;
|
|
break;
|
|
|
|
|
+ case GGML_OP_SET:
|
|
|
|
|
+ ggml_cuda_op_set(ctx, dst);
|
|
|
|
|
+ break;
|
|
|
case GGML_OP_DUP:
|
|
case GGML_OP_DUP:
|
|
|
ggml_cuda_dup(ctx, dst);
|
|
ggml_cuda_dup(ctx, dst);
|
|
|
break;
|
|
break;
|
|
@@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
op->src[0]->type == GGML_TYPE_F32 &&
|
|
op->src[0]->type == GGML_TYPE_F32 &&
|
|
|
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
|
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_OP_SET:
|
|
|
|
|
+ {
|
|
|
|
|
+ const ggml_type t = op->type;
|
|
|
|
|
+ return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
|
|
|
|
|
+ t == op->src[0]->type &&
|
|
|
|
|
+ t == op->src[1]->type;
|
|
|
|
|
+ } break;
|
|
|
case GGML_OP_CPY:
|
|
case GGML_OP_CPY:
|
|
|
{
|
|
{
|
|
|
ggml_type src0_type = op->src[0]->type;
|
|
ggml_type src0_type = op->src[0]->type;
|