cuda: optimize SOLVE_TRI using registers and FMAF (#17703)

* ggml-cuda: optimize solve_tri_f32_fast and fix stride handling

- Switch the RHS/solution matrix from shared memory to a register-based approach (`x_low`, `x_high`), reducing shared-memory pressure and bank conflicts (a standalone sketch of this pattern follows the diff below).
- Use explicit `fmaf` instructions in the reduction loop.
- Update kernel arguments to pass strides in bytes rather than elements, matching standard ggml tensor arithmetic (casting to `char *` before addition); see the byte-stride sketch after this list.
- Remove unused `MAX_K_FAST` definition.
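
A minimal sketch of the byte-stride convention (hypothetical kernel and parameter names, not ggml API): in ggml, strides such as `nb2`/`nb3` count bytes, so base pointers are cast to `char *` before offsets are added.

```cuda
#include <cstdint>

// Illustrative only: locate one batch's matrix via byte strides, ggml-style.
__global__ void batch_ptr_demo(const float * A, int64_t nb2, int64_t nb3, float * out) {
    const int i02 = blockIdx.y;  // batch index along dim 2
    const int i03 = blockIdx.z;  // batch index along dim 3
    // nb2/nb3 are byte counts, so the offsets are applied to a char pointer.
    const float * A_batch = (const float *) ((const char *) A + i02 * nb2 + i03 * nb3);
    out[i03 * gridDim.y + i02] = A_batch[0];  // read something from this batch
}
```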

* Small cleanup

* Remove comments in solve_tri.cu

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Use const for variables in solve_tri.cu

* Replace fmaf with more readable code

* Remove last fmaf

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
wsbagnsv1 committed 1 month ago
commit 5814b4dce1
1 changed file with 28 additions and 36 deletions:

ggml/src/ggml-cuda/solve_tri.cu (+28, -36)

@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
 #pragma unroll
     for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        int i0 = i + offset;
+        const int i0 = i + offset;
         if (i0 < n * n) {
             sA[i0] = A_batch[i0];
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    __syncthreads();
+    const int half = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
-
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        if (lane < row) {
+            sum += sA[row * n + lane] * x_low;
         }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row) {
+            x_low = (x_low - sum) / sA[row * n + row];
         }
+    }
 
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = sA[row * n + lane] * x_low;
+        const int j = half + lane;
+        if (j < row) {
+            sum += sA[row * n + j] * x_high;
+        }
         sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val      = sXt[col_idx * n + row];
-            const float a_diag     = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        if (lane == row - half) {
+            x_high = (x_high - sum) / sA[row * n + row];
         }
     }
 
-    __syncthreads();
-
 #pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+    for (int rr = 0; rr < 2; ++rr) {
+        const int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            const float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
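
For context, a self-contained sketch of the register-based substitution pattern from the first bullet above, simplified to n <= 32 so each lane owns a single element of x. `solve_lower_tri_warp` is a hypothetical name, and `warp_reduce_sum` is re-implemented with shuffles here rather than taken from ggml's common headers:

```cuda
#define WARP_SIZE 32

// Butterfly reduction: after the loop, every lane holds the warp-wide sum.
static __device__ float warp_reduce_sum(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset);
    }
    return v;
}

// One warp solves L * x = b for a lower-triangular n x n system, n <= 32.
// Lane i keeps x[i] in a register for the whole solve; there is no shared
// memory for the solution vector, mirroring x_low/x_high in the kernel above.
static __global__ void solve_lower_tri_warp(const float * L, const float * b,
                                            float * x, int n) {
    const int lane = threadIdx.x;
    float xi = (lane < n) ? b[lane] : 0.0f;  // starts as b[i], becomes x[i]

    for (int row = 0; row < n; ++row) {
        // Dot product of row 'row' with the already-solved prefix x[0..row-1].
        float sum = (lane < row) ? L[row * n + lane] * xi : 0.0f;
        sum = warp_reduce_sum(sum);
        if (lane == row) {  // only the owning lane finishes this row
            xi = (xi - sum) / L[row * n + row];
        }
    }
    if (lane < n) {
        x[lane] = xi;
    }
}
```

For a single system this could be launched as `solve_lower_tri_warp<<<1, WARP_SIZE>>>(d_L, d_b, d_x, n)`; the kernel in the diff extends the same idea to n <= 64 by giving each lane two registers (`x_low`, `x_high`) and handles k right-hand-side columns with one warp per column.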