@@ -5,7 +5,7 @@
 #include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
-# include <cub/device/device_scan.cuh>
+# include <cub/block/block_scan.cuh>
 #endif // GGML_CUDA_USE_CUB
 
 template<typename T, int BLOCK_SIZE>
@@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel(
         const int64_t s01, const int64_t s02, const int64_t s03,
         const int64_t s1, const int64_t s2, const int64_t s3) {
 #ifdef GGML_CUDA_USE_CUB
-    using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+    using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
 
-    __shared__ typename BlockScan::TempStorage temp_storage;
-    __shared__ T block_carry; // carry from previous tile
+    __shared__ typename BlockScanT::TempStorage temp_storage;
+    __shared__ T block_carry;
 
     const int tid = threadIdx.x;
+    constexpr int UNROLL_FACTOR = 4;
+    constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
 
     const int64_t i1 = blockIdx.x;
     const int64_t i2 = blockIdx.y;
@@ -39,29 +41,40 @@ static __global__ void cumsum_cub_kernel(
     }
     __syncthreads();
 
-    for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
-        int64_t idx = start + tid;
-        T x = (idx < ne00) ? src_row[idx] : T(0);
+    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
+        T items[UNROLL_FACTOR];
+        T thread_sum = T(0);
 
-        T inclusive;
-        T block_total;
-        BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
+#pragma unroll
+        for (int i = 0; i < UNROLL_FACTOR; i++) {
+            int64_t idx = start + tid * UNROLL_FACTOR + i;
+            T val = (idx < ne00) ? src_row[idx] : T(0);
+            thread_sum += val;
+            items[i] = thread_sum;
+        }
 
+        // Block-wide scan on thread sums
+        T thread_prefix;
+        T block_total;
+        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
         __syncthreads();
 
-        T final_val = inclusive + block_carry;
-
-        // store result
-        if (idx < ne00) {
-            dst_row[idx] = final_val;
+        // Add offset to each item and store
+        T thread_offset = thread_prefix - thread_sum + block_carry;
+        #pragma unroll
+        for (int i = 0; i < UNROLL_FACTOR; i++) {
+            int64_t idx = start + tid * UNROLL_FACTOR + i;
+            if (idx < ne00) {
+                dst_row[idx] = items[i] + thread_offset;
+            }
         }
 
         __syncthreads();
 
+        // Update carry for next tile
         if (tid == 0) {
            block_carry += block_total;
         }
-
        __syncthreads();
     }
 #else
@@ -200,7 +213,7 @@ static void cumsum_cuda(
     const int warps_per_block = block_size / warp_size;
     const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
 
-    if (use_cub) {
+    if (use_cub && ne00 >= 1024) {
         cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
             src, dst,
             ne00, ne01, ne02, ne03,
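
The tiled scan in the kernel hunk works as follows: each thread builds an inclusive running sum over its UNROLL_FACTOR consecutive items (items[]), the block scans only the per-thread totals, and thread_prefix - thread_sum recovers the exclusive prefix of the thread within the tile; adding block_carry chains the result across tiles. The host-side sketch below reproduces that arithmetic for a single tile so it can be checked against a plain sequential cumsum. It is illustrative only and not part of the patch; the toy sizes and the standalone main() are assumptions, only the variable names mirror the kernel.

// Illustrative host-side check of the kernel's offset arithmetic (not part of the patch).
#include <cstdio>
#include <vector>

int main() {
    const int BLOCK_SIZE = 4, UNROLL_FACTOR = 4; // toy sizes, not the kernel's launch config
    std::vector<float> src = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };
    std::vector<float> dst(src.size());
    float block_carry = 0.0f; // running total of all previous tiles (0 for the first tile)

    // Per-thread inclusive running sums over UNROLL_FACTOR consecutive items.
    std::vector<float> thread_sum(BLOCK_SIZE, 0.0f);
    for (int t = 0; t < BLOCK_SIZE; ++t) {
        for (int i = 0; i < UNROLL_FACTOR; ++i) {
            thread_sum[t] += src[t * UNROLL_FACTOR + i];
        }
    }

    // Inclusive scan of the per-thread totals (what BlockScan::InclusiveSum produces).
    std::vector<float> thread_prefix(BLOCK_SIZE, 0.0f);
    float running = 0.0f;
    for (int t = 0; t < BLOCK_SIZE; ++t) {
        running += thread_sum[t];
        thread_prefix[t] = running;
    }

    // thread_prefix - thread_sum is the exclusive prefix of thread t within the tile;
    // adding block_carry extends the scan across tiles.
    for (int t = 0; t < BLOCK_SIZE; ++t) {
        const float thread_offset = thread_prefix[t] - thread_sum[t] + block_carry;
        float local = 0.0f;
        for (int i = 0; i < UNROLL_FACTOR; ++i) {
            local += src[t * UNROLL_FACTOR + i];
            dst[t * UNROLL_FACTOR + i] = local + thread_offset;
        }
    }
    block_carry += running; // carry forward to the next tile

    // Expected: 1 3 6 10 15 21 28 36 45 55 66 78 91 105 120 136
    for (float v : dst) {
        printf("%g ", v);
    }
    printf("\n");
    return 0;
}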