|
|
@@ -14,12 +14,6 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
|
|
- const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
- if (idx <= nrows) {
|
|
|
- offsets[idx] = idx * ncols;
|
|
|
- }
|
|
|
-}
|
|
|
|
|
|
#ifdef GGML_CUDA_USE_CUB
|
|
|
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
@@ -31,18 +25,15 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
cudaStream_t stream) {
|
|
|
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
|
|
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
|
|
- ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
|
|
|
|
|
int * temp_indices = temp_indices_alloc.get();
|
|
|
float * temp_keys = temp_keys_alloc.get();
|
|
|
- int * d_offsets = offsets_alloc.get();
|
|
|
|
|
|
static const int block_size = 256;
|
|
|
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
|
|
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
|
|
|
|
|
- const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
|
|
- init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
|
|
+ auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
|
|
|
|
|
|
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
|
|
|
|
|
@@ -57,7 +48,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
|
|
temp_indices, dst, // values (indices)
|
|
|
ncols * nrows, nrows, // num items, num segments
|
|
|
- d_offsets, d_offsets + 1, stream);
|
|
|
+ offset_iterator, offset_iterator + 1, stream);
|
|
|
}
|
|
|
} else {
|
|
|
if (nrows == 1) {
|
|
|
@@ -66,7 +57,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
ncols, 0, sizeof(float) * 8, stream);
|
|
|
} else {
|
|
|
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
|
|
- dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
|
|
+ dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
|
|
+ stream);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -80,7 +72,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
ncols, 0, sizeof(float) * 8, stream);
|
|
|
} else {
|
|
|
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
|
|
- ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
|
|
+ ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
|
|
}
|
|
|
} else {
|
|
|
if (nrows == 1) {
|
|
|
@@ -89,8 +81,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
ncols, 0, sizeof(float) * 8, stream);
|
|
|
} else {
|
|
|
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
|
|
- temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
|
|
- stream);
|
|
|
+ temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
|
|
+ offset_iterator + 1, stream);
|
|
|
}
|
|
|
}
|
|
|
}
|