argsort.cl 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2. #ifdef cl_intel_subgroups
  3. #pragma OPENCL EXTENSION cl_intel_subgroups : enable
  4. #else
  5. #pragma OPENCL EXTENSION cl_khr_subgroups : enable
  6. #endif
  7. #ifdef cl_intel_required_subgroup_size
  8. #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
  9. #define INTEL_GPU 1
  10. #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
  11. #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
  12. #elif defined(cl_qcom_reqd_sub_group_size)
  13. #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
  14. #define ADRENO_GPU 1
  15. #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
  16. #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
  17. #endif
  18. #define SWAP(x, y, T) { T tmp = (x); (x) = (y); (y) = tmp; }
  19. enum ggml_sort_order {
  20. GGML_SORT_ORDER_ASC,
  21. GGML_SORT_ORDER_DESC,
  22. };
  23. kernel void kernel_argsort_f32_i32(
  24. global float * src0,
  25. ulong offset0,
  26. global int * dst,
  27. ulong offsetd,
  28. const int ne00,
  29. const int ne00_pad,
  30. const int order,
  31. local int * dst_row
  32. ) {
  33. // bitonic sort
  34. int col = get_local_id(0);
  35. int row = get_group_id(1);
  36. if (col >= ne00_pad) {
  37. return;
  38. }
  39. src0 = (global char *)((global char *)src0 + offset0);
  40. dst = (global float *)((global char *)dst + offsetd);
  41. global float * x_row = src0 + row * ne00;
  42. // initialize indices
  43. dst_row[col] = col;
  44. barrier(CLK_LOCAL_MEM_FENCE);
  45. for (int k = 2; k <= ne00_pad; k *= 2) {
  46. for (int j = k / 2; j > 0; j /= 2) {
  47. int ixj = col ^ j;
  48. if (ixj > col) {
  49. if ((col & k) == 0) {
  50. if (dst_row[col] >= ne00 ||
  51. (dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ?
  52. x_row[dst_row[col]] > x_row[dst_row[ixj]] :
  53. x_row[dst_row[col]] < x_row[dst_row[ixj]]))
  54. ) {
  55. SWAP(dst_row[col], dst_row[ixj], int);
  56. }
  57. } else {
  58. if (dst_row[ixj] >= ne00 ||
  59. (dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ?
  60. x_row[dst_row[col]] < x_row[dst_row[ixj]] :
  61. x_row[dst_row[col]] > x_row[dst_row[ixj]]))
  62. ) {
  63. SWAP(dst_row[col], dst_row[ixj], int);
  64. }
  65. }
  66. }
  67. barrier(CLK_LOCAL_MEM_FENCE);
  68. }
  69. }
  70. // copy the result to dst without the padding
  71. if (col < ne00) {
  72. dst[row * ne00 + col] = dst_row[col];
  73. }
  74. }