Browse Source

CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator (#18964)

* CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator

Strided iterator was added in [CCCL 3.1](https://github.com/NVIDIA/cccl/releases/tag/v3.1.0), which is packaged into [CTK 13.1](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id5).

* Unindent as per code review request
Oliver Simons 1 week ago
parent
commit
5bd341c9a1
1 changed file with 18 additions and 1 deletion
  1. 18 1
      ggml/src/ggml-cuda/argsort.cu

+ 18 - 1
ggml/src/ggml-cuda/argsort.cu

@@ -2,6 +2,9 @@
 
 #ifdef GGML_CUDA_USE_CUB
 #    include <cub/cub.cuh>
+#    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
+#        define STRIDED_ITERATOR_AVAILABLE
+#    endif
 using namespace cub;
 #endif  // GGML_CUDA_USE_CUB
 
@@ -14,6 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
     }
 }
 
+#ifndef STRIDED_ITERATOR_AVAILABLE
+static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx <= nrows) {
+        offsets[idx] = idx * ncols;
+    }
+}
+#endif  // STRIDED_ITERATOR_AVAILABLE
 
 #ifdef GGML_CUDA_USE_CUB
 void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
@@ -33,8 +44,14 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
     const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
     init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
 
+#ifdef STRIDED_ITERATOR_AVAILABLE
     auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
-
+#else
+    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
+    int *                     offset_iterator = offsets_alloc.get();
+    const dim3                offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
+#endif
     CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
 
     size_t temp_storage_bytes = 0;