
CUDA: Replace init_offsets kernel with iterators in cub-based argsort (#18930)

* CUDA: Replace `init_offsets` with iterators in argsort

This is a QOL improvement, saving us the kernel launch and device buffer
previously needed to materialize the segment offsets; they are now computed
on the fly by an iterator (see the sketch below)

* Remove unnecessary include from top-k.cu
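A minimal sketch of the idea behind the change (not code from this commit; assumes CCCL >= 3.2 for <cuda/iterator> and uses made-up ncols/nrows values): the removed init_offsets kernel stored offsets[i] = i * ncols in an (nrows + 1)-element device buffer, whereas a counting iterator strided by ncols yields the same values on demand, with no allocation and no kernel launch.

#include <cuda/iterator>
#include <cassert>

int main() {
    const int ncols = 128;  // hypothetical segment width (row length)
    const int nrows = 4;    // hypothetical number of segments (rows)

    // Element i evaluates to i * ncols -- the value init_offsets used to store.
    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);

    for (int i = 0; i <= nrows; ++i) {
        assert(*(offset_iterator + i) == i * ncols);
    }
    // offset_iterator     -> per-segment begin offsets (i * ncols)
    // offset_iterator + 1 -> per-segment end offsets   ((i + 1) * ncols)
    return 0;
}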
Oliver Simons · 1 week ago · commit d1e3556481
2 changed files with 7 additions and 16 deletions
  1. ggml/src/ggml-cuda/argsort.cu (+7, -15)
  2. ggml/src/ggml-cuda/top-k.cu (+0, -1)

ggml/src/ggml-cuda/argsort.cu (+7, -15)

@@ -14,12 +14,6 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
     }
 }
 
-static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx <= nrows) {
-        offsets[idx] = idx * ncols;
-    }
-}
 
 #ifdef GGML_CUDA_USE_CUB
 void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
@@ -31,18 +25,15 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                               cudaStream_t     stream) {
     ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
-    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
 
     int *   temp_indices = temp_indices_alloc.get();
     float * temp_keys    = temp_keys_alloc.get();
-    int *   d_offsets    = offsets_alloc.get();
 
     static const int block_size = 256;
     const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
     init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
 
-    const dim3 offset_grid((nrows + block_size - 1) / block_size);
-    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
 
     CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
 
@@ -57,7 +48,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
             DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                            temp_indices, dst,                                  // values (indices)
                                            ncols * nrows, nrows,  // num items, num segments
-                                           d_offsets, d_offsets + 1, stream);
+                                           offset_iterator, offset_iterator + 1, stream);
         }
     } else {
         if (nrows == 1) {
@@ -66,7 +57,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                                  ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
-                                                     dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+                                                     dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
+                                                     stream);
         }
     }
 
@@ -80,7 +72,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                        ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
-                                           ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+                                           ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
         }
     } else {
         if (nrows == 1) {
@@ -89,8 +81,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                                  ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
-                                                     temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
-                                                     stream);
+                                                     temp_indices, dst, ncols * nrows, nrows, offset_iterator,
+                                                     offset_iterator + 1, stream);
         }
     }
 }
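
For reference, a self-contained sketch of the two-phase CUB call pattern the code above follows, now with iterator offsets in both the size query and the actual sort. This is not code from the repo: the function name, the raw buffers, and the cudaMallocAsync/cudaFreeAsync scratch handling (the repo uses its own ggml_cuda_pool) are illustrative only.

#include <cub/cub.cuh>
#include <cuda/iterator>

// Sketch: sort each of the nrows segments of `keys` (ncols floats each) in place,
// writing the reordered indices from `vals` into `vals_out`.
static void segmented_sort_pairs_sketch(float * keys, int * vals, int * vals_out,
                                        int ncols, int nrows, cudaStream_t stream) {
    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);

    // Phase 1: with d_temp_storage == nullptr, CUB only reports the scratch size.
    size_t temp_storage_bytes = 0;
    cub::DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, keys, keys, vals, vals_out,
                                        ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);

    // Phase 2: run the sort with the allocated scratch buffer (error checks omitted).
    void * d_temp_storage = nullptr;
    cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream);
    cub::DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, keys, keys, vals, vals_out,
                                        ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
    cudaFreeAsync(d_temp_storage, stream);
}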

ggml/src/ggml-cuda/top-k.cu (+0, -1)

@@ -4,7 +4,6 @@
 #ifdef GGML_CUDA_USE_CUB
 #    include <cub/cub.cuh>
 #    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
-#        include <cuda/iterator>
 #        define CUB_TOP_K_AVAILABLE
 using namespace cub;
 #    endif  // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2