|
|
@@ -29,8 +29,8 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|
|
const int nrows,
|
|
|
ggml_sort_order order,
|
|
|
cudaStream_t stream) {
|
|
|
- ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
|
|
- ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
|
|
+ ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ((size_t) ncols) * nrows);
|
|
|
+ ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ((size_t) ncols) * nrows);
|
|
|
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
|
|
|
|
|
int * temp_indices = temp_indices_alloc.get();
|