Bladeren bron

SOLVE_TRI CUDA kernel for small matrices (#17457)

Piotr Wilkin (ilintar) 1 maand geleden
bovenliggende
commit
cd0e3a7a3b
4 gewijzigde bestanden met toevoegingen van 215 en 0 verwijderingen
  1. 6 0
      ggml/src/ggml-cuda/ggml-cuda.cu
  2. 203 0
      ggml/src/ggml-cuda/solve_tri.cu
  3. 3 0
      ggml/src/ggml-cuda/solve_tri.cuh
  4. 3 0
      tests/test-backend-ops.cpp

+ 6 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -53,6 +53,7 @@
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
+#include "ggml-cuda/solve_tri.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_SGD:
             ggml_cuda_opt_step_sgd(ctx, dst);
             break;
+        case GGML_OP_SOLVE_TRI:
+            ggml_cuda_op_solve_tri(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -4255,6 +4259,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return true;
+        case GGML_OP_SOLVE_TRI:
+            return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
         default:
             return false;
     }

+ 203 - 0
ggml/src/ggml-cuda/solve_tri.cu

@@ -0,0 +1,203 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "solve_tri.cuh"
+
+#define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
+// ======================
+// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
+// ======================
+// When ncols_template == 0 the bounds for the loops in this function are not
+// known and can't be unrolled. As we want to keep pragma unroll for all other
+// cases we supress the clang transformation warning here.
+#ifdef __clang__
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wpass-failed"
+#endif  // __clang__
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
+                                          const float * __restrict__ B,
+                                          float * __restrict__ X,
+                                          const uint3  ne02,
+                                          const size_t nb02,
+                                          const size_t nb03,
+                                          const size_t nb12,
+                                          const size_t nb13,
+                                          const size_t nb2,
+                                          const size_t nb3,
+                                          const int    n_arg,
+                                          const int    k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int lane      = threadIdx.x;
+    const int col_idx   = threadIdx.y;
+
+    if (col_idx >= k) {
+        return;
+    }
+
+    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02     = i02_i03.y;
+    const int64_t i03     = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
+    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
+
+    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+
+#pragma unroll
+    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
+        int i0 = i + offset;
+        if (i0 < n * n) {
+            sA[i0] = A_batch[i0];
+        }
+    }
+
+    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+
+#pragma unroll
+    for (int i = 0; i < rows_per_warp; i++) {
+        const int i0 = lane + i * WARP_SIZE;
+        if (i0 < n) {
+            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int row = 0; row < n; ++row) {
+        float sum = 0.0f;
+
+        {
+            int j = lane;
+            if (j < row) {
+                sum += sA[row * n + j] * sXt[col_idx * n + j];
+            }
+        }
+        if (row >= WARP_SIZE) {
+            int j = WARP_SIZE + lane;
+            if (j < row) {
+                sum += sA[row * n + j] * sXt[col_idx * n + j];
+            }
+        }
+
+        sum = warp_reduce_sum(sum);
+
+        if (lane == 0) {
+            const float b_val      = sXt[col_idx * n + row];
+            const float a_diag     = sA[row * n + row];
+            // no safeguards for division by zero because that indicates corrupt
+            // data anyway
+            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int i = 0; i < rows_per_warp; i++) {
+        const int i0 = lane + i * WARP_SIZE;
+        if (i0 < n) {
+            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+        }
+    }
+}
+#ifdef __clang__
+#    pragma clang diagnostic pop
+#endif  // __clang__
+
+static void solve_tri_f32_cuda(const float * A,
+                               const float * B,
+                               float *       X,
+                               int           n,
+                               int           k,
+                               int64_t       ne02,
+                               int64_t       ne03,
+                               size_t        nb02,
+                               size_t        nb03,
+                               size_t        nb12,
+                               size_t        nb13,
+                               size_t        nb2,
+                               size_t        nb3,
+                               cudaStream_t  stream) {
+    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+    dim3        threads(WARP_SIZE, k);
+    dim3        grid(ne02 * ne03);
+    if (n == 64) {
+        switch (k) {
+            case 32:
+                solve_tri_f32_fast<64, 32>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 16:
+                solve_tri_f32_fast<64, 16>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 14:
+                solve_tri_f32_fast<64, 14>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 12:
+                solve_tri_f32_fast<64, 12>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 10:
+                solve_tri_f32_fast<64, 10>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 8:
+                solve_tri_f32_fast<64, 8>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 6:
+                solve_tri_f32_fast<64, 6>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 4:
+                solve_tri_f32_fast<64, 4>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 2:
+                solve_tri_f32_fast<64, 2>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 1:
+                solve_tri_f32_fast<64, 1>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            default:
+                solve_tri_f32_fast<0, 0>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+        }
+    } else {  // run general case
+        solve_tri_f32_fast<0, 0>
+            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+    }
+}
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];  // A (triangular n x x matrix)
+    const ggml_tensor * src1 = dst->src[1];  // B (right hand side of n x k equation columns)
+
+    ggml_is_contiguous(src0);
+    ggml_is_contiguous(src1);
+
+    const int64_t n = src0->ne[0];
+    const int64_t k = src1->ne[0];
+
+    GGML_ASSERT(n <= 64);
+    GGML_ASSERT(k <= 32);
+
+    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
+                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                       dst->nb[3] / sizeof(float), ctx.stream());
+}

+ 3 - 0
ggml/src/ggml-cuda/solve_tri.cuh

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 3 - 0
tests/test-backend-ops.cpp

@@ -7935,6 +7935,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
 
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {