argsort.cu

#include "argsort.cuh"

template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
    a = b;
    b = tmp;
}
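
// k_argsort_f32_i32 argsorts one row per thread block with a bitonic sorting
// network: ncols threads each own one slot of dst_row and repeatedly
// compare-exchange pairs of indices at distance j within a merge step of
// width k, comparing the float values the indices point to. The (col & k)
// test picks the direction of each compare so that every step produces
// bitonic subsequences; after log2(ncols)*(log2(ncols)+1)/2 barrier-separated
// rounds the whole row of indices is sorted.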
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.y;

    if (col >= ncols) return;

    const float * x_row = x + row * ncols;
    int * dst_row = dst + row * ncols;

    // initialize indices
    if (col < ncols) {
        dst_row[col] = col;
    }
    __syncthreads();

    for (int k = 2; k <= ncols; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            __syncthreads();
        }
    }
}
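
// The launcher below uses one thread block per row with exactly ncols
// threads, so the col >= ncols early return never fires and every thread
// reaches the __syncthreads() barriers. Since blockDim.x is capped at 1024
// in CUDA, this configuration handles rows of at most 1024 elements.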
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
    // bitonic sort requires ncols to be a power of 2
    GGML_ASSERT((ncols & (ncols - 1)) == 0);

    const dim3 block_dims(ncols, 1, 1);
    const dim3 block_nums(1, nrows, 1);
    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
    } else {
        GGML_ASSERT(false);
    }
}

void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}
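
// Example: for the contiguous F32 row {3.0f, 1.0f, 4.0f, 2.0f},
// GGML_SORT_ORDER_ASC writes the index row {1, 3, 0, 2}, since
// x[1] <= x[3] <= x[0] <= x[2]; each row of src0 is sorted independently
// by its own thread block.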