binbcast.cu 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #include "binbcast.cuh"
  // Repeat "operator": the src0 operand is discarded and the (broadcast) src1
  // value is returned unchanged, which turns the broadcast kernel into a tiling copy.
  static __device__ __forceinline__ float op_repeat(const float a, const float b) {
      GGML_UNUSED(a);
      return b;
  }
  // Elementwise addition used by ggml_cuda_op_add: a + b.
  static __device__ __forceinline__ float op_add(const float a, const float b) { return a + b; }
  // Elementwise multiplication used by ggml_cuda_op_mul: a * b.
  static __device__ __forceinline__ float op_mul(const float a, const float b) { return a * b; }
  // Elementwise division used by ggml_cuda_op_div: a / b (no guard on b).
  static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; }
  // Broadcasting elementwise binary op, 3D launch:
  //   dst[i3,i2,i1,i0] = bin_op(src0[i3,i2,i1,i0], src1[i3%ne13, i2%ne12, i1%ne11, i0%ne10])
  // Launch layout (as configured by bin_bcast_cuda):
  //   x: dim 0, walked with a grid-stride loop (a thread may handle several i0)
  //   y: dim 1
  //   z: dims 2 and 3 flattened as z = i2*ne3 + i3
  // Strides are in elements: s1..s3 for dst, s11..s13 for src1; dim 0 is read
  // with unit stride for both. src0 is addressed with dst's offsets, so it is
  // assumed to share dst's layout. A null src0 is read as 0.0f (the repeat path
  // passes nullptr). No shared memory is used.
  template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
  static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
          int ne0, int ne1, int ne2, int ne3,
          int ne10, int ne11, int ne12, int ne13,
          /*int s0, */ int s1, int s2, int s3,
          /*int s10,*/ int s11, int s12, int s13) {
      const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
      const int i1  = blockDim.y*blockIdx.y + threadIdx.y;
      const int i23 = blockDim.z*blockIdx.z + threadIdx.z;
      const int i2  = i23 / ne3;
      const int i3  = i23 % ne3;

      if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
          return;
      }

      // Broadcast: src1 indices wrap around its (possibly smaller) extents.
      const int i11 = i1 % ne11;
      const int i12 = i2 % ne12;
      const int i13 = i3 % ne13;

      // Row offsets are formed in 64-bit: the int products overflow once an
      // offset exceeds INT_MAX elements.
      const int64_t i_src0 = (int64_t)i3*s3   + (int64_t)i2*s2   + (int64_t)i1*s1;
      const int64_t i_src1 = (int64_t)i13*s13 + (int64_t)i12*s12 + (int64_t)i11*s11;
      const int64_t i_dst  = i_src0;

      // Never offset a null src0 pointer (pointer arithmetic on null is UB).
      const src0_t * src0_row = src0 ? src0 + i_src0 : nullptr;
      const src1_t * src1_row = src1 + i_src1;
      dst_t        * dst_row  = dst  + i_dst;

      const int stride0 = blockDim.x*gridDim.x;
      for (int i0 = i0s; i0 < ne0; i0 += stride0) {
          const int i10 = i0 % ne10;
          dst_row[i0] = (dst_t)bin_op(src0_row ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
      }
  }
  // Flat 1D variant of k_bin_bcast, used by bin_bcast_cuda as a fallback when the
  // 3D launch grid would be too large. One thread per dst element; the flat index
  // is unravelled as i = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0. Operand conventions
  // are the same as k_bin_bcast: element strides s1..s3 (dst, also used for src0)
  // and s11..s13 (src1), unit stride in dim 0, src1 broadcast by modulo, and a
  // null src0 read as 0.0f.
  // All index math is 64-bit: this kernel handles the largest tensors, whose flat
  // index can exceed INT_MAX.
  template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
  static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
          int ne0, int ne1, int ne2, int ne3,
          int ne10, int ne11, int ne12, int ne13,
          /*int s0, */ int s1, int s2, int s3,
          /*int s10,*/ int s11, int s12, int s13) {
      const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

      const int64_t i3 =  i/((int64_t)ne2*ne1*ne0);
      const int64_t i2 = (i/((int64_t)ne1*ne0)) % ne2;
      const int64_t i1 = (i/ne0) % ne1;
      const int64_t i0 =  i % ne0;

      // Tail threads of the last block fall past the end (i3 >= ne3).
      if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
          return;
      }

      const int64_t i10 = i0 % ne10;
      const int64_t i11 = i1 % ne11;
      const int64_t i12 = i2 % ne12;
      const int64_t i13 = i3 % ne13;

      const int64_t i_src0 = i3*s3   + i2*s2   + i1*s1;
      const int64_t i_src1 = i13*s13 + i12*s12 + i11*s11;
      const int64_t i_dst  = i_src0;

      // Never offset a null src0 pointer (pointer arithmetic on null is UB).
      const src0_t * src0_row = src0 ? src0 + i_src0 : nullptr;
      const src1_t * src1_row = src1 + i_src1;
      dst_t        * dst_row  = dst  + i_dst;

      dst_row[i0] = (dst_t)bin_op(src0_row ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
  }
  template<float (*bin_op)(const float, const float)>
  struct bin_bcast_cuda {
      // Launches the broadcasting kernel for dst = bin_op(src0, src1).
      //
      // src1 is broadcast onto dst's shape: the kernels take every src1 index
      // modulo src1's extent. src0 is read with dst's offsets, so it is assumed to
      // share dst's shape and element strides. NOTE(review): that is not checked
      // here. src0_dd may be null (ggml_cuda_op_repeat passes nullptr); the kernels
      // then feed 0.0f as the left operand.
      //
      // Leading dims along which src1 is not broadcast are folded into dim 0 so the
      // kernel sees longer rows. A fold is only done when the folded rows are packed
      // back-to-back in memory for both dst and src1.
      template<typename src0_t, typename src1_t, typename dst_t>
      void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
              const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
              cudaStream_t stream) {

          GGML_TENSOR_BINARY_OP_LOCALS

          // An empty dst has nothing to compute. Returning early also avoids the
          // divisions by ne0..ne3 below and a zero-sized launch configuration.
          if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) {
              return;
          }

          // nr[i] == 1 when src1 spans dst's full extent along dim i. A smaller,
          // broadcast src1 extent divides to 0 (presumably ne1i <= nei is checked
          // by the caller).
          const int nr[4] = { (int)(ne10/ne0), (int)(ne11/ne1), (int)(ne12/ne2), (int)(ne13/ne3) };

          int64_t cne0[] = {ne0,  ne1,  ne2,  ne3};   // dst extents
          int64_t cne1[] = {ne10, ne11, ne12, ne13};  // src1 extents
          size_t  cnb0[] = {nb0,  nb1,  nb2,  nb3};   // dst strides in bytes
          size_t  cnb1[] = {nb10, nb11, nb12, nb13};  // src1 strides in bytes

          // Dim 1 can be merged into dim 0 only if it starts right where dim 0 ends.
          auto is_packed = [](const int64_t cne[], const size_t cnb[]) {
              return cnb[1] == cnb[0]*(size_t)cne[0];
          };
          // Merge dim 1 into dim 0 and shift the higher dims (extents and strides)
          // down. The vacated dim 3 gets extent 1, so its stride is never used.
          auto collapse = [](int64_t cne[], size_t cnb[]) {
              cne[0] *= cne[1];
              cne[1]  = cne[2];
              cne[2]  = cne[3];
              cne[3]  = 1;
              cnb[1]  = cnb[2];
              cnb[2]  = cnb[3];
          };

          // Fold dims 1, 2, 3 into dim 0 while src1 is not broadcast along them
          // (nor along dim 0) and the rows stay packed.
          if (nr[0] == 1) {
              for (int i = 1; i < 4 && nr[i] == 1; i++) {
                  if (!is_packed(cne0, cnb0) || !is_packed(cne1, cnb1)) {
                      break;
                  }
                  collapse(cne0, cnb0);
                  collapse(cne1, cnb1);
              }
          }

          // Element strides handed to the kernels.
          const size_t s0  = cnb0[0] / sizeof(dst_t);
          const size_t s1  = cnb0[1] / sizeof(dst_t);
          const size_t s2  = cnb0[2] / sizeof(dst_t);
          const size_t s3  = cnb0[3] / sizeof(dst_t);
          const size_t s10 = cnb1[0] / sizeof(src1_t);
          const size_t s11 = cnb1[1] / sizeof(src1_t);
          const size_t s12 = cnb1[2] / sizeof(src1_t);
          const size_t s13 = cnb1[3] / sizeof(src1_t);

          // The kernels index dim 0 with unit stride.
          GGML_ASSERT(s0 == 1);
          GGML_ASSERT(s10 == 1);

          const int block_size = 128;

          // Dims 2 and 3 share the grid z axis.
          const int64_t ne23 = cne0[2]*cne0[3];
          // The 3D kernel's x threads each cover about two elements of dim 0.
          const int64_t hne0 = std::max<int64_t>(cne0[0]/2, 1);

          // Clamp in 64-bit before narrowing, so huge extents cannot truncate to 0.
          dim3 block_dims;
          block_dims.x = (unsigned int) std::min<int64_t>(hne0, block_size);
          block_dims.y = (unsigned int) std::min<int64_t>(cne0[1], block_size / block_dims.x);
          block_dims.z = (unsigned int) std::min<int64_t>(std::min<int64_t>(ne23, block_size / block_dims.x / block_dims.y), 64);

          const dim3 block_nums(
              (hne0    + block_dims.x - 1) / block_dims.x,
              (cne0[1] + block_dims.y - 1) / block_dims.y,
              (ne23    + block_dims.z - 1) / block_dims.z
          );

          if (block_nums.y > 65535 || block_nums.z > 65535) {
              // Grid y and z are both limited to 65535 blocks: fall back to the
              // flat 1D kernel.
              const int block_num = (cne0[0]*cne0[1]*ne23 + block_size - 1) / block_size;
              k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
                  src0_dd, src1_dd, dst_dd,
                  cne0[0], cne0[1], cne0[2], cne0[3],
                  cne1[0], cne1[1], cne1[2], cne1[3],
                  /* s0, */ s1, s2, s3,
                  /* s10, */ s11, s12, s13);
          } else {
              k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
                  src0_dd, src1_dd, dst_dd,
                  cne0[0], cne0[1], cne0[2], cne0[3],
                  cne1[0], cne1[1], cne1[2], cne1[3],
                  /* s0, */ s1, s2, s3,
                  /* s10, */ s11, s12, s13);
          }
      }
  };
  // Type dispatch for the broadcasting binary ops.
  // src1 must be F32. The supported (src0, dst) type pairs are (F32, F32),
  // (F16, F16) and (F16, F32). Any other combination is reported on stderr and
  // then fails GGML_ASSERT(false).
  template<class op>
  static void ggml_cuda_op_bin_bcast(
          const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
          const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
      GGML_ASSERT(src1->type == GGML_TYPE_F32);

      const float * src1_f32 = (const float *) src1_dd;
      const auto    type_src0 = src0->type;
      const auto    type_dst  = dst->type;

      if (type_src0 == GGML_TYPE_F32) {
          if (type_dst == GGML_TYPE_F32) {
              op()(src0, src1, dst, (const float *) src0_dd, src1_f32, (float *) dst_dd, stream);
              return;
          }
      } else if (type_src0 == GGML_TYPE_F16) {
          if (type_dst == GGML_TYPE_F16) {
              op()(src0, src1, dst, (const half *) src0_dd, src1_f32, (half *) dst_dd, stream);
              return;
          }
          if (type_dst == GGML_TYPE_F32) {
              op()(src0, src1, dst, (const half *) src0_dd, src1_f32, (float *) dst_dd, stream);
              return;
          }
      }

      fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
          ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
      GGML_ASSERT(false);
  }
  // dst = src[0] tiled to dst's shape. dst itself is passed as the shape-giving
  // "src0" with a null data pointer, and op_repeat keeps only the broadcast value.
  void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src = dst->src[0];
      ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src, dst, nullptr, src->data, dst->data, ctx.stream());
  }
  // dst = src[0] + src[1], with src[1] broadcast onto dst's shape.
  void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * lhs = dst->src[0];
      const ggml_tensor * rhs = dst->src[1];
      ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(lhs, rhs, dst, lhs->data, rhs->data, dst->data, ctx.stream());
  }
  // dst = src[0] * src[1], with src[1] broadcast onto dst's shape.
  void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * lhs = dst->src[0];
      const ggml_tensor * rhs = dst->src[1];
      ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(lhs, rhs, dst, lhs->data, rhs->data, dst->data, ctx.stream());
  }
  // dst = src[0] / src[1], with src[1] broadcast onto dst's shape.
  void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * lhs = dst->src[0];
      const ggml_tensor * rhs = dst->src[1];
      ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(lhs, rhs, dst, lhs->data, rhs->data, dst->data, ctx.stream());
  }