concat.cu 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. #include "concat.cuh"
  2. // contiguous kernels
  // Concatenate two contiguous f32 tensors along dim 0 (the innermost axis).
  //
  // Expected launch (see concat_f32_cuda): grid = (ceil(ne0 / blockDim.x), ne1, ne2) with 1-D
  // blocks, one thread per dst element. blockIdx.y is the row index i1, blockIdx.z the plane
  // index i2, and gridDim.y is the row count ne1 shared by x, y and dst.
  //   x   : rows of ne00 elements          -> dst columns [0, ne00)
  //   y   : rows of (ne0 - ne00) elements  -> dst columns [ne00, ne0)
  //   dst : rows of ne0 elements
  //
  // All index math is done in int64_t. Otherwise the plane term (blockIdx.z * ne0 * gridDim.y)
  // is evaluated in 32-bit unsigned arithmetic and wraps once a single dst slice holds more
  // than 2^31 elements.
  static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
      const int64_t i0 = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
      if (i0 >= ne0) {
          return; // tail of the last block along x
      }

      const int64_t i1  = blockIdx.y;
      const int64_t i2  = blockIdx.z;
      const int64_t ne1 = gridDim.y;

      const int64_t offset_dst = i0 + i1*ne0 + i2*ne0*ne1;

      if (i0 < ne00) { // src0
          const int64_t offset_src = i0 + i1*ne00 + i2*ne00*ne1;
          dst[offset_dst] = x[offset_src];
      } else {         // src1
          const int64_t ne10 = ne0 - ne00; // row length of y
          const int64_t offset_src = (i0 - ne00) + i1*ne10 + i2*ne10*ne1;
          dst[offset_dst] = y[offset_src];
      }
  }
  // Concatenate two contiguous f32 tensors along dim 1 (rows).
  //
  // Expected launch (see concat_f32_cuda): grid = (ceil(ne0 / blockDim.x), ne1, ne2) with 1-D
  // blocks, one thread per dst element. blockIdx.y is the dst row i1, blockIdx.z the plane i2,
  // and gridDim.y is the dst row count ne1.
  //   x   : ne01 rows of ne0 elements per plane         -> dst rows [0, ne01)
  //   y   : (ne1 - ne01) rows of ne0 elements per plane -> dst rows [ne01, ne1)
  //
  // All index math is done in int64_t. Otherwise the plane term (blockIdx.z * ne0 * ...) is
  // evaluated in 32-bit unsigned arithmetic and wraps once a single dst slice holds more than
  // 2^31 elements.
  static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
      const int64_t i0 = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
      if (i0 >= ne0) {
          return; // tail of the last block along x
      }

      const int64_t i1  = blockIdx.y;
      const int64_t i2  = blockIdx.z;
      const int64_t ne1 = gridDim.y;

      const int64_t offset_dst = i0 + i1*ne0 + i2*ne0*ne1;

      if (i1 < ne01) { // src0
          const int64_t offset_src = i0 + i1*ne0 + i2*ne0*ne01;
          dst[offset_dst] = x[offset_src];
      } else {         // src1
          const int64_t ne11 = ne1 - ne01; // rows per plane in y
          const int64_t offset_src = i0 + (i1 - ne01)*ne0 + i2*ne0*ne11;
          dst[offset_dst] = y[offset_src];
      }
  }
  // Concatenate two contiguous f32 tensors along dim 2 (planes).
  //
  // Expected launch (see concat_f32_cuda): grid = (ceil(ne0 / blockDim.x), ne1, ne2) with 1-D
  // blocks, one thread per dst element. blockIdx.y is the row i1, blockIdx.z the dst plane i2,
  // and gridDim.y is the row count ne1 shared by x, y and dst.
  //   x : planes [0, ne02) of dst
  //   y : planes [ne02, ne2) of dst
  // Row and plane sizes are the same in all three tensors. For i2 < ne02 the src0 offset
  // therefore equals the dst offset.
  //
  // All index math is done in int64_t. Otherwise the plane term (blockIdx.z * ne0 * gridDim.y)
  // is evaluated in 32-bit unsigned arithmetic and wraps once a single dst slice holds more
  // than 2^31 elements.
  static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
      const int64_t i0 = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
      if (i0 >= ne0) {
          return; // tail of the last block along x
      }

      const int64_t i1  = blockIdx.y;
      const int64_t i2  = blockIdx.z;
      const int64_t ne1 = gridDim.y;

      const int64_t offset_dst = i0 + i1*ne0 + i2*ne0*ne1;

      if (i2 < ne02) { // src0
          const int64_t offset_src = i0 + i1*ne0 + i2*ne0*ne1;
          dst[offset_dst] = x[offset_src];
      } else {         // src1
          const int64_t offset_src = i0 + i1*ne0 + (i2 - ne02)*ne0*ne1;
          dst[offset_dst] = y[offset_src];
      }
  }
  // Host-side launcher for the contiguous concat kernels; enqueues one launch on `stream`.
  //
  // Handles a single i3 slice. The grid covers dst as (ceil(ne0 / CUDA_CONCAT_BLOCK_SIZE), ne1,
  // ne2) with one thread per dst element. `dim` picks the kernel: 0 and 1 map to their own
  // kernels, and any other value goes to the dim-2 kernel. ggml_cuda_op_concat handles dim 3 with
  // memcpys instead of calling this.
  static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
      const int blocks_x = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; // ceil-div
      const dim3 grid(blocks_x, ne1, ne2);

      switch (dim) {
          case 0:
              concat_f32_dim0<<<grid, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
              break;
          case 1:
              concat_f32_dim1<<<grid, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
              break;
          default:
              concat_f32_dim2<<<grid, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
              break;
      }
  }
  85. // non-contiguous kernel (slow)
  // Generic (slow) f32 concat for non-contiguous sources. Every element is addressed through the
  // byte strides nb*, so any layout of src0, src1 and dst is accepted.
  //
  // Expected launch (see ggml_cuda_op_concat): grid = (ne1, ne2, ne3), one block per dst row
  // (i1, i2, i3) = (blockIdx.x, blockIdx.y, blockIdx.z). The threads of a block step through the
  // ne0 elements of that row in strides of blockDim.x.
  //
  // An element whose coordinates lie inside src0's extents (ne00, ne01, ne02, ne03) is read from
  // src0. Any other element is read from src1, after shifting its `dim` coordinate back by src0's
  // extent on that axis. The src1 extents and the dst extents other than ne0 are part of the
  // signature but unused.
  static __global__ void concat_f32_non_cont(
          const char * src0,
          const char * src1,
                char * dst,
             int64_t   ne00,
             int64_t   ne01,
             int64_t   ne02,
             int64_t   ne03,
            uint64_t   nb00,
            uint64_t   nb01,
            uint64_t   nb02,
            uint64_t   nb03,
             int64_t /*ne10*/,
             int64_t /*ne11*/,
             int64_t /*ne12*/,
             int64_t /*ne13*/,
            uint64_t   nb10,
            uint64_t   nb11,
            uint64_t   nb12,
            uint64_t   nb13,
             int64_t   ne0,
             int64_t /*ne1*/,
             int64_t /*ne2*/,
             int64_t /*ne3*/,
            uint64_t   nb0,
            uint64_t   nb1,
            uint64_t   nb2,
            uint64_t   nb3,
             int32_t   dim) {
      const int64_t i3 = blockIdx.z;
      const int64_t i2 = blockIdx.y;
      const int64_t i1 = blockIdx.x;

      // Shift applied to src1 coordinates: src0's extent on the concat axis, zero on the others.
      // These are four scalars rather than an array indexed by the runtime `dim`. A dynamically
      // indexed local array is normally placed in (slow) local memory, and an out-of-range `dim`
      // would write past its end.
      const int64_t o0 = dim == 0 ? ne00 : 0;
      const int64_t o1 = dim == 1 ? ne01 : 0;
      const int64_t o2 = dim == 2 ? ne02 : 0;
      const int64_t o3 = dim == 3 ? ne03 : 0;

      // 64-bit loop index (ne0 is 64-bit) so `i0 += blockDim.x` cannot overflow on very wide rows.
      for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
          const bool in_src0 = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03;

          const float * x = in_src0
              ? (const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00)
              : (const float *)(src1 + (i3 - o3)*nb13 + (i2 - o2)*nb12 + (i1 - o1)*nb11 + (i0 - o0)*nb10);

          float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
          *y = *x;
      }
  }
  // CUDA implementation of the concat op: writes concat(dst->src[0], dst->src[1]) into dst along
  // the axis stored in dst->op_params[0]. Only F32 sources and destination are accepted.
  //
  // Three paths:
  //   - both sources contiguous, dim != 3: concat_f32_cuda is called once per i3 slice;
  //   - both sources contiguous, dim == 3: dst is src0's bytes followed by src1's bytes, so two
  //     device-to-device cudaMemcpyAsync calls suffice;
  //   - otherwise: concat_f32_non_cont, addressing every element through the byte strides.
  // All work is issued on ctx.stream(); no synchronization happens here.
  // NOTE(review): the contiguous paths index dst as if it were contiguous as well. This is
  // presumably guaranteed because the op allocates dst itself; confirm.
  void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1];

      cudaStream_t stream = ctx.stream();

      const int32_t dim = ((int32_t *) dst->op_params)[0];

      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT(src1->type == GGML_TYPE_F32);
      GGML_ASSERT(dst->type  == GGML_TYPE_F32);

      if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
          // Slow path: one block per dst row (i1, i2, i3).
          const dim3 row_grid(dst->ne[1], dst->ne[2], dst->ne[3]);
          concat_f32_non_cont<<<row_grid, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
              (const char *) src0->data,
              (const char *) src1->data,
              (char       *) dst->data,
              src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
              src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
              src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
              src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
              dst->ne[0],  dst->ne[1],  dst->ne[2],  dst->ne[3],
              dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
              dim);
          return;
      }

      const float * src0_f = (const float *) src0->data;
      const float * src1_f = (const float *) src1->data;
      float       * dst_f  = (float       *) dst->data;

      if (dim == 3) {
          // Concatenating along the outermost axis of contiguous tensors: append src1 after src0.
          const size_t nbytes0 = ggml_nbytes(src0);
          const size_t nbytes1 = ggml_nbytes(src1);
          CUDA_CHECK(cudaMemcpyAsync(dst_f, src0_f, nbytes0, cudaMemcpyDeviceToDevice, stream));
          CUDA_CHECK(cudaMemcpyAsync(dst_f + nbytes0/sizeof(float), src1_f, nbytes1, cudaMemcpyDeviceToDevice, stream));
          return;
      }

      // Distance, in floats rather than bytes, between consecutive i3 slices of each tensor.
      const size_t slice_src0 = src0->nb[3] / sizeof(float);
      const size_t slice_src1 = src1->nb[3] / sizeof(float);
      const size_t slice_dst  = dst->nb[3]  / sizeof(float);

      for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
          concat_f32_cuda(
              src0_f + i3*slice_src0,
              src1_f + i3*slice_src1,
              dst_f  + i3*slice_dst,
              src0->ne[0], src0->ne[1], src0->ne[2],
              dst->ne[0],  dst->ne[1],  dst->ne[2],
              dim, stream);
      }
  }