SOLVE_TRI extension to more dimensions (#17793)

* Extended TRI

* Fix whitespace

* chore: update webui build output

* Just use cuBLAS for everything...

* Merge both versions

* Remove incorrect imports causing failures for CI

* Still failing... remove all direct cublas imports and rely on common imports from "common.cuh"

* Defines for hipBlas

* Aaaand MUSA defines...

* I hate this job...

* Stupid typo...

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Piotr Wilkin (ilintar), 1 month ago
Parent commit: 53ecd4fdb9
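
For context: GGML_OP_SOLVE_TRI solves the batched triangular system A·X = B, where A is an n×n lower-triangular matrix and B carries k right-hand-side columns per batch. Before this change the CUDA backend only accepted n <= 64 and k <= 32 (the limits of the warp-based fast kernel); this commit keeps that kernel for small shapes and falls back to cublasStrsmBatched for everything larger, which is what lets supports_op return true unconditionally below. A minimal CPU sketch of what one batch computes, assuming row-major contiguous buffers as in the kernels (solve_tri_ref is an illustrative name, not a ggml function):

    // Forward substitution: solve A * X = B for one batch, where A is
    // n x n lower-triangular and B/X are n x k, all row-major.
    static void solve_tri_ref(const float * A, const float * B, float * X, int n, int k) {
        for (int col = 0; col < k; ++col) {
            for (int row = 0; row < n; ++row) {
                float sum = B[row * k + col];
                for (int j = 0; j < row; ++j) {
                    sum -= A[row * n + j] * X[j * k + col];  // eliminate already-solved rows
                }
                X[row * k + col] = sum / A[row * n + row];   // divide by the diagonal entry
            }
        }
    }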

+ 2 - 2
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -4630,9 +4630,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
         case GGML_OP_DIAG:
-            return true;
         case GGML_OP_SOLVE_TRI:
-            return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
+            return true;
+
         default:
             return false;
     }

+ 95 - 15
ggml/src/ggml-cuda/solve_tri.cu

@@ -3,6 +3,80 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
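+// Computes per-batch base pointers into A and X so that a single
+// cublasStrsmBatched call can address every (i2, i3) matrix in the batch.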
+static __global__ void get_batch_pointers(const float *  A,
+                                          float *        X,
+                                          const float ** A_ptrs,
+                                          float **       X_ptrs,
+                                          int64_t        ne02,
+                                          int64_t        total_batches,
+                                          size_t         s02,
+                                          size_t         s03,
+                                          size_t         s2,
+                                          size_t         s3) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_batches) {
+        return;
+    }
+
+    const int64_t i3 = idx / ne02;
+    const int64_t i2 = idx % ne02;
+
+    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
+    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
+}
+
+static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
+                                 const float *               A,
+                                 const float *               B,
+                                 float *                     X,
+                                 int                         n,
+                                 int                         k,
+                                 int64_t                     ne02,
+                                 int64_t                     ne03,
+                                 size_t                      s02,
+                                 size_t                      s03,
+                                 size_t                      s12,
+                                 size_t                      s13,
+                                 size_t                      s2,
+                                 size_t                      s3,
+                                 cudaStream_t                stream) {
+    const float   alpha         = 1.0f;
+    const int64_t total_batches = ne02 * ne03;
+    if (total_batches == 0) {
+        return;
+    }
+
+    // Bulk copy B -> X (contiguous tensors)
+    if (X != B) {
+        const int64_t total_elements_BX = n * k * total_batches;
+        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+    }
+
+    const int id = ggml_cuda_get_device();
+
+    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
+    ggml_cuda_pool_alloc<float *>       X_ptrs_alloc(ctx.pool(id), total_batches);
+
+    const float ** A_ptrs_dev = A_ptrs_alloc.get();
+    float **       X_ptrs_dev = X_ptrs_alloc.get();
+
+    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
+                                                                        total_batches, s02, s03, s2, s3);
+
+    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+
+    // Full FP32 math is required here: with TF32 tensor-op math enabled,
+    // the solve loses enough precision to produce RMSE errors.
+    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
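+    // ggml buffers are row-major while cuBLAS is column-major, so each matrix
+    // is implicitly transposed: the lower-triangular A reads as the
+    // upper-triangular A^T, and solving A*X = B becomes the right-side solve
+    // X^T * A^T = B^T, done in place on X (which already holds a copy of B).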
+    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
+                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
+
+    // restore the TF32 math mode that common.cuh configures by default
+    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
+
+    GGML_UNUSED_VARS(s12, s13);
+}
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -63,7 +137,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
     float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    const int half = WARP_SIZE;
+    const int half      = WARP_SIZE;
     const int nrows_low = (n < half) ? n : half;
 
 #pragma unroll
@@ -81,8 +155,8 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 
 #pragma unroll
     for (int row = half; row < n; ++row) {
-        float sum = sA[row * n + lane] * x_low;
-        const int j = half + lane;
+        float     sum = sA[row * n + lane] * x_low;
+        const int j   = half + lane;
         if (j < row) {
             sum += sA[row * n + j] * x_high;
         }
@@ -97,7 +171,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     for (int rr = 0; rr < 2; ++rr) {
         const int row = rr * WARP_SIZE + lane;
         if (row < n) {
-            const float val = (row < half) ? x_low : x_high;
+            const float val            = (row < half) ? x_low : x_high;
             X_batch[row * k + col_idx] = val;
         }
     }
@@ -176,20 +250,26 @@ static void solve_tri_f32_cuda(const float * A,
 }
 
 void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];  // A (triangular n x x matrix)
-    const ggml_tensor * src1 = dst->src[1];  // B (right hand side of n x k equation columns)
+    const ggml_tensor * src0 = dst->src[0];  // A (n×n, lower triangular)
+    const ggml_tensor * src1 = dst->src[1];  // B (n×k)
 
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
 
-    const int64_t n = src0->ne[0];
-    const int64_t k = src1->ne[0];
+    const int64_t n    = src0->ne[0];
+    const int64_t k    = src1->ne[0];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    GGML_ASSERT(n <= 64);
-    GGML_ASSERT(k <= 32);
-
-    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
-                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                       dst->nb[3] / sizeof(float), ctx.stream());
+    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+                           src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                           dst->nb[3] / sizeof(float), ctx.stream());
+    } else {
+        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                             dst->nb[3] / sizeof(float), ctx.stream());
+    }
 }

+ 4 - 0
ggml/src/ggml-cuda/vendors/hip.h

@@ -19,6 +19,9 @@
 #define CUDA_R_16F  HIPBLAS_R_16F
 #define CUDA_R_16BF HIPBLAS_R_16B
 #define CUDA_R_32F  HIPBLAS_R_32F
+#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
+#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
+#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
 #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
 #define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
@@ -30,6 +33,7 @@
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define __all_sync(mask, var) __all(var)
 #define __any_sync(mask, var) __any(var)
+#define cublasStrsmBatched hipblasStrsmBatched
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx
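
Both vendor headers rely on the same preprocessor-alias pattern: shared CUDA source spells only the cublas* names, and hip.h / musa.h retarget them at compile time, so solve_tri.cu itself needs no vendor #ifdefs. A toy, self-contained illustration of the pattern (all names below are hypothetical, not the real headers):

    #include <stdio.h>

    static void vendor_solve(void) { printf("vendor solve\n"); }
    static void cuda_solve(void)   { printf("CUDA solve\n"); }

    #ifdef BUILD_FOR_VENDOR
    #define portableSolve vendor_solve  // retarget the shared name
    #else
    #define portableSolve cuda_solve
    #endif

    int main(void) {
        portableSolve();  // one call site, compiled against either backend
        return 0;
    }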

+ 5 - 0
ggml/src/ggml-cuda/vendors/musa.h

@@ -12,11 +12,16 @@
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N MUBLAS_OP_N
 #define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
+#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
+#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
+#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
 #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
 #define CUDA_R_16F  MUSA_R_16F
 #define CUDA_R_16BF MUSA_R_16BF
 #define CUDA_R_32F  MUSA_R_32F
+#define cublasStrsmBatched mublasStrsmBatched
 #define cublasComputeType_t cudaDataType_t
 #define cublasCreate mublasCreate
 #define cublasDestroy mublasDestroy

+ 19 - 3
tests/test-backend-ops.cpp

@@ -7861,9 +7861,24 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
-    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));
 
     for (bool v : {false, true}) {
         for (bool circular : {false, true}) {
@@ -8064,12 +8079,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
 
-    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
-    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
     // qwen3next with CHUNK_SIZE 64
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
     // qwen3next with CHUNK_SIZE 128
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 }));
 
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
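
The new shapes can be exercised in isolation with the test-backend-ops binary; something like the following, assuming the usual filter flags (check the binary's help output for the exact spelling):

    ./build/bin/test-backend-ops test -b CUDA0 -o SOLVE_TRI
    ./build/bin/test-backend-ops perf -b CUDA0 -o SOLVE_TRI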

BIN
tools/server/public/index.html.gz