소스 검색

Add DIAG for CUDA (#17873)

* Add DIAG for CUDA

* Refactor parameters
Piotr Wilkin (ilintar) 1 개월 전
부모
커밋
b63509262a
4개의 변경된 파일116개의 추가작업 그리고 0개의 파일을 삭제
  1. 77 0
      ggml/src/ggml-cuda/diag.cu
  2. 5 0
      ggml/src/ggml-cuda/diag.cuh
  3. 5 0
      ggml/src/ggml-cuda/ggml-cuda.cu
  4. 29 0
      tests/test-backend-ops.cpp

+ 77 - 0
ggml/src/ggml-cuda/diag.cu

@@ -0,0 +1,77 @@
+#include "convert.cuh"
+#include "diag.cuh"
+#include "ggml.h"
+
+template <typename T>
+static __global__ void diag_kernel(T * __restrict__ dst,
+                                   const T * __restrict__ src,
+                                   const int64_t ne0,
+                                   const int64_t ne1,
+                                   const int64_t ne2,
+                                   const int64_t ne3,
+                                   const int64_t total_elements) {
+    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (global_idx >= total_elements) {
+        return;
+    }
+
+    const int64_t i0 = global_idx % ne0;
+    const int64_t i1 = (global_idx / ne0) % ne1;
+    const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;
+    const int64_t i3 = global_idx / (ne0 * ne1 * ne2);
+
+    const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
+
+    if (i0 == i1) {
+        const int64_t batch_idx = i3 * ne2 + i2;
+        const int64_t src_idx   = batch_idx * ne0 + i0;
+        dst[dst_idx]            = src[src_idx];
+    } else {
+        dst[dst_idx] = ggml_cuda_cast<T>(0);
+    }
+    GGML_UNUSED_VARS(ne3);
+}
+
+void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    void *       dst_d  = dst->data;
+    const void * src0_d = src0->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    GGML_ASSERT(ne00 == ne0);
+    GGML_ASSERT(ne01 == 1);
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne3);
+
+    const int64_t n_elems    = ggml_nelements(dst);
+    const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,
+                                                                         ne1, ne2, ne3, n_elems);
+            break;
+        case GGML_TYPE_F16:
+            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,
+                                                                         ne1, ne2, ne3, n_elems);
+            break;
+        default:
+            GGML_ABORT("unsupported type");
+    }
+}

+ 5 - 0
ggml/src/ggml-cuda/diag.cuh

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_DIAG_BLOCK_SIZE 256
+
+void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 5 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -20,6 +20,7 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
+#include "ggml-cuda/diag.cuh"
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
@@ -2641,6 +2642,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
                 break;
+        case GGML_OP_DIAG:
+            ggml_cuda_op_diag(ctx, dst);
+            break;
         case GGML_OP_DIAG_MASK_INF:
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
@@ -4624,6 +4628,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FILL:
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
+        case GGML_OP_DIAG:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;

+ 29 - 0
tests/test-backend-ops.cpp

@@ -6253,6 +6253,31 @@ struct test_solve_tri : public test_case {
     }
 };
 
+// GGML_OP_DIAG
+struct test_diag : public test_case {
+    const ggml_type              type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override { return VARS_TO_STR2(type, ne); }
+
+    test_diag(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = { 10, 1, 4, 3 })
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        GGML_ASSERT(ne[1] == 1);
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_diag(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -7826,6 +7851,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
     test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
 
+    test_cases.emplace_back(new test_diag());
+    test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));
+    test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 256, 1, 8, 16 }));
+
     test_cases.emplace_back(new test_solve_tri());
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));