getrows.cu 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. #include "getrows.cuh"
  2. #include "dequantize.cuh"
  3. #include "convert.cuh"
  // Row-gather kernel for block-quantized src0.
  //
  // Every thread dequantizes one pair of values from the src0 row selected through src1
  // and stores the pair into the matching dst row.
  //
  // Grid layout (x and y are swapped on purpose: the maximum allowed grid size for x is higher):
  //   blockIdx.x                            -> i10, position along dim 0 of src1
  //   blockIdx.y*blockDim.x + threadIdx.x   -> index of the value pair inside the row
  //   blockIdx.z                            -> i11*ne12 + i12
  //
  // s1, s2, s3 (dst) and s10, s11, s12 (src1) are added to typed pointers, i.e. they are
  // strides in elements; nb01, nb02, nb03 (src0) are added to a char pointer, i.e. bytes.
  // ne00 is the number of values per row; threads whose first value lies past it exit early.
  template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  static __global__ void k_get_rows(
          const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
          const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
          /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
          /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
          /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
          const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

      // First of the two row positions covered by this thread.
      const int col = 2*(blockIdx.y*blockDim.x + threadIdx.x);
      if (col >= ne00) {
          return;
      }

      const int idx0 = blockIdx.x;
      const int idx1 = blockIdx.z / ne12;
      const int idx2 = blockIdx.z % ne12;

      // Row of src0 to read, as stored in src1.
      const int src_row_id = src1[idx0*s10 + idx1*s11 + idx2*s12];

      const char * src_bytes = (const char *) src0;
      const void * src_row   = src_bytes + src_row_id*nb01 + idx1*nb02 + idx2*nb03;
      dst_t      * out_row   = dst + idx0*s1 + idx1*s2 + idx2*s3;

      const int within_block = col % qk;              // offset of col inside its block
      const int block_id     = col / qk;              // block index
      const int quant_id     = within_block / qr;     // quant index
      const int out_first    = (col - within_block) + quant_id; // dst block start + quant index

      // Distance in dst between the two dequantized values:
      // adjacent when qr == 1, otherwise half a block apart.
      constexpr int pair_stride = qr == 1 ? 1 : qk/2;

      // dequantize
      dfloat2 pair;
      dequantize_kernel(src_row, block_id, quant_id, pair);

      out_row[out_first]               = ggml_cuda_cast<dst_t>(pair.x);
      out_row[out_first + pair_stride] = ggml_cuda_cast<dst_t>(pair.y);
  }
  // Row-gather kernel for non-quantized src0.
  //
  // Every thread copies a single value from the src0 row selected through src1 into the
  // matching dst row, converting it with ggml_cuda_cast<dst_t>.
  //
  // Grid layout (x and y are swapped on purpose: the maximum allowed grid size for x is higher):
  //   blockIdx.x                            -> i10, position along dim 0 of src1
  //   blockIdx.y*blockDim.x + threadIdx.x   -> position inside the row
  //   blockIdx.z                            -> i11*ne12 + i12
  //
  // s1, s2, s3 (dst) and s10, s11, s12 (src1) are strides in elements;
  // nb01, nb02, nb03 (src0) are applied to a char pointer, i.e. they are strides in bytes.
  template<typename src0_t, typename dst_t>
  static __global__ void k_get_rows_float(
          const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
          const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
          /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
          /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
          /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
          const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

      const int col = blockIdx.y*blockDim.x + threadIdx.x;
      if (col >= ne00) {
          return; // thread falls past the end of the row
      }

      const int idx0 = blockIdx.x;
      const int idx1 = blockIdx.z / ne12;
      const int idx2 = blockIdx.z % ne12;

      // Row of src0 to read, as stored in src1.
      const int src_row_id = src1[idx0*s10 + idx1*s11 + idx2*s12];

      const char   * src_bytes = (const char *) src0;
      const src0_t * in_row    = (const src0_t *)(src_bytes + src_row_id*nb01 + idx1*nb02 + idx2*nb03);
      dst_t        * out_row   = dst + idx0*s1 + idx1*s2 + idx2*s3;

      out_row[col] = ggml_cuda_cast<dst_t>(in_row[col]);
  }
  // Backward pass of the row gather.
  //
  // Each thread owns one (dst_row, col) element of dst and writes to it the sum of
  // grad[i*ncols + col] over every i in [0, nrows_grad) for which rows[i] == dst_row.
  // Elements whose row never appears in rows are written as 0.
  //
  // Grid layout: x covers the columns (guarded against ncols), y covers the dst rows.
  // The row index is not bounds-checked here, so the launch must not provide more
  // y threads than dst has rows.
  template<typename grad_t, typename dst_t>
  static __global__ void k_get_rows_back_float(
          const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
      const int col = blockIdx.x*blockDim.x + threadIdx.x;
      if (col >= ncols) {
          return;
      }

      const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;

      // Accumulate in float regardless of grad_t / dst_t.
      float acc = 0.0f;
      for (int64_t r = 0; r < nrows_grad; ++r) {
          if (rows[r] == dst_row) {
              acc += grad[r*ncols + col];
          }
      }

      dst[dst_row*ncols + col] = acc;
  }
  // Host-side launcher of k_get_rows for block-quantized src0.
  //
  // Parameters:
  //   src0_d             device pointer to the quantized source data
  //   src1_d             device pointer to the int32 row indices
  //   dst_d              device pointer to the output
  //   ne00               number of values per row; must be even because each thread emits two values
  //   nb01, nb02, nb03   src0 strides in bytes (forwarded unchanged to the kernel)
  //   ne10, ne11, ne12   extents of src1; they determine the grid (x = ne10, z = ne11*ne12)
  //   nb10, nb11, nb12   src1 strides in bytes, converted to int32 element strides here
  //   nb1, nb2, nb3      dst strides in bytes, converted to dst_t element strides here
  //   stream             stream the kernel is enqueued on
  //
  // Launch shape: CUDA_GET_ROWS_BLOCK_SIZE threads per block, each block covering
  // 2*CUDA_GET_ROWS_BLOCK_SIZE values of a row along grid y.
  //
  // An empty gather (any of ne00, ne10, ne11, ne12 is zero) returns without launching:
  // it would otherwise request a grid with a zero-sized dimension, which is not a valid
  // launch configuration.
  template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
  static void get_rows_cuda_q(
          const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
          const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
          const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
          const size_t nb1, const size_t nb2, const size_t nb3,
          cudaStream_t stream) {
      // Every kernel thread writes a pair of values.
      GGML_ASSERT(ne00 % 2 == 0);

      // Nothing to gather -> nothing to launch.
      if (ne00 <= 0 || ne10 <= 0 || ne11 <= 0 || ne12 <= 0) {
          return;
      }

      const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
      const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
      const dim3 block_nums(ne10, block_num_y, ne11*ne12);

      // strides in elements
      // const size_t s0 = nb0 / sizeof(dst_t);
      const size_t s1 = nb1 / sizeof(dst_t);
      const size_t s2 = nb2 / sizeof(dst_t);
      const size_t s3 = nb3 / sizeof(dst_t);

      const size_t s10 = nb10 / sizeof(int32_t);
      const size_t s11 = nb11 / sizeof(int32_t);
      const size_t s12 = nb12 / sizeof(int32_t);
      // const size_t s13 = nb13 / sizeof(int32_t);

      k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
          src0_d, src1_d, dst_d,
          ne00, /*ne01, ne02, ne03,*/
          /*ne10, ne11,*/ ne12, /*ne13,*/
          /* s0,*/ s1, s2, s3,
          /* nb00,*/ nb01, nb02, nb03,
          s10, s11, s12/*, s13*/);
  }
  // Host-side launcher of k_get_rows_float for non-quantized src0.
  //
  // Parameters:
  //   src0_d             device pointer to the typed source data
  //   src1_d             device pointer to the int32 row indices
  //   dst_d              device pointer to the output
  //   ne00               number of values per row
  //   nb01, nb02, nb03   src0 strides in bytes (forwarded unchanged to the kernel)
  //   ne10, ne11, ne12   extents of src1; they determine the grid (x = ne10, z = ne11*ne12)
  //   nb10, nb11, nb12   src1 strides in bytes, converted to int32 element strides here
  //   nb1, nb2, nb3      dst strides in bytes, converted to dst_t element strides here
  //   stream             stream the kernel is enqueued on
  //
  // Launch shape: CUDA_GET_ROWS_BLOCK_SIZE threads per block, one thread per value,
  // with the row split into blocks along grid y.
  //
  // An empty gather (any of ne00, ne10, ne11, ne12 is zero) returns without launching:
  // it would otherwise request a grid with a zero-sized dimension, which is not a valid
  // launch configuration.
  template<typename src0_t, typename dst_t>
  static void get_rows_cuda_float(
          const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
          const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
          const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
          const size_t nb1, const size_t nb2, const size_t nb3,
          cudaStream_t stream) {
      // Nothing to gather -> nothing to launch.
      if (ne00 <= 0 || ne10 <= 0 || ne11 <= 0 || ne12 <= 0) {
          return;
      }

      const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
      const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
      const dim3 block_nums(ne10, block_num_y, ne11*ne12);

      // strides in elements
      // const size_t s0 = nb0 / sizeof(dst_t);
      const size_t s1 = nb1 / sizeof(dst_t);
      const size_t s2 = nb2 / sizeof(dst_t);
      const size_t s3 = nb3 / sizeof(dst_t);

      const size_t s10 = nb10 / sizeof(int32_t);
      const size_t s11 = nb11 / sizeof(int32_t);
      const size_t s12 = nb12 / sizeof(int32_t);
      // const size_t s13 = nb13 / sizeof(int32_t);

      k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
          src0_d, src1_d, dst_d,
          ne00, /*ne01, ne02, ne03,*/
          /*ne10, ne11,*/ ne12, /*ne13,*/
          /* s0,*/ s1, s2, s3,
          /* nb00,*/ nb01, nb02, nb03,
          s10, s11, s12/*, s13*/);
  }
  // Second dispatch stage: the destination element type dst_t is already fixed, this
  // function picks the launcher that matches the runtime type of src0.
  //
  // F16 / F32 / I32 / BF16 sources are reinterpreted as typed pointers and go through
  // get_rows_cuda_float; Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0 go through get_rows_cuda_q with
  // their QK* / QR* constants and dequantize function. Every other type aborts.
  // All shape, stride and stream arguments are forwarded unchanged.
  template <typename dst_t>
  static void ggml_cuda_get_rows_switch_src0_type(
          const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
          const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
          const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
          const size_t nb1, const size_t nb2, const size_t nb3,
          cudaStream_t stream) {
      switch (src0_type) {
          // ---- plain element types ----
          case GGML_TYPE_F16: {
              const half * src0_h = static_cast<const half *>(src0_d);
              get_rows_cuda_float(src0_h, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_F32: {
              const float * src0_f = static_cast<const float *>(src0_d);
              get_rows_cuda_float(src0_f, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_I32: {
              const int32_t * src0_i = static_cast<const int32_t *>(src0_d);
              get_rows_cuda_float(src0_i, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_BF16: {
              const nv_bfloat16 * src0_bf = static_cast<const nv_bfloat16 *>(src0_d);
              get_rows_cuda_float(src0_bf, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;

          // ---- block-quantized types ----
          case GGML_TYPE_Q4_0: {
              get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_Q4_1: {
              get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_Q5_0: {
              get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_Q5_1: {
              get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_Q8_0: {
              get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;

          default: {
              // TODO: k-quants
              GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
          } break;
      }
  }
  // Type-erased entry point of the row gather.
  //
  // First dispatch stage: reinterprets dst_d according to dst_type (F32, I32, F16 or BF16)
  // and hands over to ggml_cuda_get_rows_switch_src0_type, which resolves src0_type.
  // Any other destination type aborts. Shapes, strides and the stream pass through untouched.
  void get_rows_cuda(
          const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
          int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
          int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
          size_t nb1, size_t nb2, size_t nb3,
          cudaStream_t stream) {
      switch (dst_type) {
          case GGML_TYPE_F32: {
              float * out = static_cast<float *>(dst_d);
              ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, out,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_I32: {
              int32_t * out = static_cast<int32_t *>(dst_d);
              ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, out,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_F16: {
              half * out = static_cast<half *>(dst_d);
              ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, out,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          case GGML_TYPE_BF16: {
              nv_bfloat16 * out = static_cast<nv_bfloat16 *>(dst_d);
              ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, out,
                  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
          } break;
          default: {
              GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
          } break;
      }
  }
  // Row-gather operator on ggml tensors.
  //
  // dst->src[0] is the tensor rows are taken from, dst->src[1] holds the I32 row indices.
  // After validating the layout it forwards the raw pointers, types, extents and byte
  // strides to get_rows_cuda on the context's stream.
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1];

      // Provides the ne* / nb* locals used below.
      GGML_TENSOR_BINARY_OP_LOCALS

      cudaStream_t stream = ctx.stream();

      // Indices must be int32 and src1 must not extend into its 4th dimension.
      GGML_ASSERT(src1->type == GGML_TYPE_I32);
      GGML_ASSERT(ne13 == 1);

      // The innermost stride of every tensor has to equal its type size.
      GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
      GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
      GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));

      const int32_t * row_ids = (const int32_t *) src1->data;

      get_rows_cuda(src0->data, src0->type, row_ids, dst->data, dst->type,
          ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
  }
  // Backward operator of the row gather.
  //
  // dst->src[0] carries the gradients of the forward output, dst->src[1] the row indices
  // that were used in the forward pass. Only contiguous F32 gradients / F32 output with
  // I32 indices and no extent in dims 2 and 3 are accepted. The launch uses one thread
  // per dst element: x blocks cover the ne00 columns, grid y has one entry per dst row (ne1).
  void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
      const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass

      // Provides the ne* / nb* locals used below.
      GGML_TENSOR_BINARY_OP_LOCALS

      // Supported element types.
      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT(src1->type == GGML_TYPE_I32);
      GGML_ASSERT(dst->type  == GGML_TYPE_F32);

      // The kernel indexes all three buffers densely.
      GGML_ASSERT(ggml_is_contiguous(src0));
      GGML_ASSERT(ggml_is_contiguous(src1));
      GGML_ASSERT(ggml_is_contiguous(dst));

      // Dims 2 and 3 must be trivial for every tensor.
      GGML_ASSERT(ne02*ne03 == 1);
      GGML_ASSERT(ne12*ne13 == 1);
      GGML_ASSERT(ne2*ne3 == 1);

      const float   * grad_d = (const float   *) src0->data;
      const int32_t * rows_d = (const int32_t *) src1->data;
      float         * out_d  = (float         *) dst->data;

      cudaStream_t stream = ctx.stream();

      const int n_blocks_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
      const dim3 threads(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
      const dim3 blocks(n_blocks_x, ne1, 1);

      k_get_rows_back_float<<<blocks, threads, 0, stream>>>(grad_d, rows_d, out_d, ne00, ne10);
  }
  238. }