Prechádzať zdrojové kódy

cuda : add FILL op support (#17851)

* cuda : add FILL op support

* cuda : add missing FILL op files
Jay Zenith 1 mesiac pred
rodič
commit
51e0c2d917

+ 37 - 0
ggml/src/ggml-cuda/fill.cu

@@ -0,0 +1,37 @@
+#include "fill.cuh"
+#include "convert.cuh"
+
+#define CUDA_FILL_BLOCK_SIZE 256
+
+template <typename T>
+static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
+    const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = value;
+}
+
+void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    void * dst_d = dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    float value;
+    memcpy(&value, dst->op_params, sizeof(float));
+
+    const int64_t k = ggml_nelements(dst);
+    const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
+            break;
+        case GGML_TYPE_F16:
+            fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
+            break;
+        default:
+            GGML_ABORT("unsupported type");
+    }
+}

+ 3 - 0
ggml/src/ggml-cuda/fill.cuh

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 5 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -56,6 +56,7 @@
 #include "ggml-cuda/solve_tri.cuh"
 #include "ggml-cuda/tri.cuh"
 #include "ggml-cuda/cumsum.cuh"
+#include "ggml-cuda/fill.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOLVE_TRI:
             ggml_cuda_op_solve_tri(ctx, dst);
             break;
+        case GGML_OP_FILL:
+            ggml_cuda_op_fill(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -4617,6 +4621,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_FILL:
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
             return true;