@@ -0,0 +1,203 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "solve_tri.cuh"
+
+#define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
+// ======================
+// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
+// ======================
+// When n_template/k_template == 0 the bounds for the loops in this function
+// are not known and can't be unrolled. As we want to keep pragma unroll for
+// all other cases we suppress the clang transformation warning here.
+#ifdef __clang__
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
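+// Solves A * X = B for one batch of matrices by forward substitution; only the
+// lower-triangular part of the n x n matrix A is read. Each thread block
+// handles one (i02, i03) matrix of the batch: threadIdx.y selects the column
+// of B/X being solved and threadIdx.x is the lane used for the warp-level
+// reductions.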
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
+                                          const float * __restrict__ B,
+                                          float * __restrict__ X,
+                                          const uint3 ne02,
+                                          const size_t nb02,
+                                          const size_t nb03,
+                                          const size_t nb12,
+                                          const size_t nb13,
+                                          const size_t nb2,
+                                          const size_t nb3,
+                                          const int n_arg,
+                                          const int k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int lane = threadIdx.x;
+    const int col_idx = threadIdx.y;
+
+    if (col_idx >= k) {
+        return;
+    }
+
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
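+    // sA caches the whole n x n matrix A; sXt holds the solution transposed so
+    // that each RHS column occupies a contiguous run of n floats, which the
+    // lanes of a warp then read at consecutive addresses in the loops below.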
+    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
+    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
+
+    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+
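+    // All threads of the block cooperatively stage A into shared memory.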
+#pragma unroll
+    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
+        int i0 = i + offset;
+        if (i0 < n * n) {
+            sA[i0] = A_batch[i0];
+        }
+    }
+
+    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+
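+    // Each warp loads its own column of B into sXt (stored transposed).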
+#pragma unroll
+    for (int i = 0; i < rows_per_warp; i++) {
+        const int i0 = lane + i * WARP_SIZE;
+        if (i0 < n) {
+            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
+        }
+    }
+
+    __syncthreads();
+
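+    // Forward substitution: for each row the warp accumulates the partial dot
+    // product over the already-solved entries (at most two products per lane
+    // since n <= 64), reduces it across the warp, and lane 0 divides by the
+    // diagonal element to finish the row.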
+#pragma unroll
+    for (int row = 0; row < n; ++row) {
+        float sum = 0.0f;
+
+        {
+            int j = lane;
+            if (j < row) {
+                sum += sA[row * n + j] * sXt[col_idx * n + j];
+            }
+        }
+        if (row >= WARP_SIZE) {
+            int j = WARP_SIZE + lane;
+            if (j < row) {
+                sum += sA[row * n + j] * sXt[col_idx * n + j];
+            }
+        }
+
+        sum = warp_reduce_sum(sum);
+
+        if (lane == 0) {
+            const float b_val = sXt[col_idx * n + row];
+            const float a_diag = sA[row * n + row];
+            // no safeguards for division by zero because that indicates corrupt
+            // data anyway
+            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        }
+    }
+
+    __syncthreads();
+
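+    // Write the solved columns back to X in global memory, undoing the
+    // transposed shared-memory layout.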
+#pragma unroll
+    for (int i = 0; i < rows_per_warp; i++) {
+        const int i0 = lane + i * WARP_SIZE;
+        if (i0 < n) {
+            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+        }
+    }
+}
+#ifdef __clang__
+#    pragma clang diagnostic pop
+#endif // __clang__
+
+static void solve_tri_f32_cuda(const float * A,
+                               const float * B,
+                               float * X,
+                               int n,
+                               int k,
+                               int64_t ne02,
+                               int64_t ne03,
+                               size_t nb02,
+                               size_t nb03,
+                               size_t nb12,
+                               size_t nb13,
+                               size_t nb2,
+                               size_t nb3,
+                               cudaStream_t stream) {
+    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+    dim3 threads(WARP_SIZE, k);
+    dim3 grid(ne02 * ne03);
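+    // The loops in the kernel only unroll when n and k are compile-time
+    // constants, so the common shapes (n == 64 with a small set of k values)
+    // get dedicated template instantiations; everything else falls back to the
+    // generic <0, 0> variant that receives n and k at run time.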
+    if (n == 64) {
+        switch (k) {
+            case 32:
+                solve_tri_f32_fast<64, 32>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 16:
+                solve_tri_f32_fast<64, 16>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 14:
+                solve_tri_f32_fast<64, 14>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 12:
+                solve_tri_f32_fast<64, 12>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 10:
+                solve_tri_f32_fast<64, 10>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 8:
+                solve_tri_f32_fast<64, 8>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 6:
+                solve_tri_f32_fast<64, 6>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 4:
+                solve_tri_f32_fast<64, 4>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 2:
+                solve_tri_f32_fast<64, 2>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 1:
+                solve_tri_f32_fast<64, 1>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            default:
+                solve_tri_f32_fast<0, 0>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+        }
+    } else {  // run general case
+        solve_tri_f32_fast<0, 0>
+            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+    }
+}
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
+    const ggml_tensor * src1 = dst->src[1]; // B (right-hand side of the system, n x k)
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int64_t n = src0->ne[0];
+    const int64_t k = src1->ne[0];
+
+    GGML_ASSERT(n <= 64);
+    GGML_ASSERT(k <= 32);
+
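+    // ggml strides (nb[]) are in bytes, but the kernel indexes float pointers,
+    // so convert them to units of float before launching.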
+    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
+                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                       dst->nb[3] / sizeof(float), ctx.stream());
+}