|
|
@@ -3,7 +3,6 @@
|
|
|
#include "solve_tri.cuh"
|
|
|
|
|
|
#define MAX_N_FAST 64
|
|
|
-#define MAX_K_FAST 32
|
|
|
|
|
|
// ======================
|
|
|
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
|
|
|
@@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
|
|
|
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
|
|
|
|
|
|
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
|
|
|
- __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
|
|
|
|
|
|
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
|
|
|
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
|
|
|
- int i0 = i + offset;
|
|
|
+ const int i0 = i + offset;
|
|
|
if (i0 < n * n) {
|
|
|
sA[i0] = A_batch[i0];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
|
|
|
+ __syncthreads();
|
|
|
|
|
|
-#pragma unroll
|
|
|
- for (int i = 0; i < rows_per_warp; i++) {
|
|
|
- const int i0 = lane + i * WARP_SIZE;
|
|
|
- if (i0 < n) {
|
|
|
- sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
|
|
|
- }
|
|
|
- }
|
|
|
+ float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
|
|
|
+ float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
|
|
|
|
|
|
- __syncthreads();
|
|
|
+ const int half = WARP_SIZE;
|
|
|
+ const int nrows_low = (n < half) ? n : half;
|
|
|
|
|
|
#pragma unroll
|
|
|
- for (int row = 0; row < n; ++row) {
|
|
|
+ for (int row = 0; row < nrows_low; ++row) {
|
|
|
float sum = 0.0f;
|
|
|
-
|
|
|
- {
|
|
|
- int j = lane;
|
|
|
- if (j < row) {
|
|
|
- sum += sA[row * n + j] * sXt[col_idx * n + j];
|
|
|
- }
|
|
|
+ if (lane < row) {
|
|
|
+ sum += sA[row * n + lane] * x_low;
|
|
|
}
|
|
|
- if (row >= WARP_SIZE) {
|
|
|
- int j = WARP_SIZE + lane;
|
|
|
- if (j < row) {
|
|
|
- sum += sA[row * n + j] * sXt[col_idx * n + j];
|
|
|
- }
|
|
|
+ sum = warp_reduce_sum(sum);
|
|
|
+
|
|
|
+ if (lane == row) {
|
|
|
+ x_low = (x_low - sum) / sA[row * n + row];
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
+#pragma unroll
|
|
|
+ for (int row = half; row < n; ++row) {
|
|
|
+ float sum = sA[row * n + lane] * x_low;
|
|
|
+ const int j = half + lane;
|
|
|
+ if (j < row) {
|
|
|
+ sum += sA[row * n + j] * x_high;
|
|
|
+ }
|
|
|
sum = warp_reduce_sum(sum);
|
|
|
|
|
|
- if (lane == 0) {
|
|
|
- const float b_val = sXt[col_idx * n + row];
|
|
|
- const float a_diag = sA[row * n + row];
|
|
|
- // no safeguards for division by zero because that indicates corrupt
|
|
|
- // data anyway
|
|
|
- sXt[col_idx * n + row] = (b_val - sum) / a_diag;
|
|
|
+ if (lane == row - half) {
|
|
|
+ x_high = (x_high - sum) / sA[row * n + row];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- __syncthreads();
|
|
|
-
|
|
|
#pragma unroll
|
|
|
- for (int i = 0; i < rows_per_warp; i++) {
|
|
|
- const int i0 = lane + i * WARP_SIZE;
|
|
|
- if (i0 < n) {
|
|
|
- X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
|
|
|
+ for (int rr = 0; rr < 2; ++rr) {
|
|
|
+ const int row = rr * WARP_SIZE + lane;
|
|
|
+ if (row < n) {
|
|
|
+ const float val = (row < half) ? x_low : x_high;
|
|
|
+ X_batch[row * k + col_idx] = val;
|
|
|
}
|
|
|
}
|
|
|
}
|