ssm-conv.cu

#include "ssm-conv.cuh"

#include <type_traits>  // std::integral_constant, used for the kernel-size dispatch below

// Depthwise 1D convolution over the token dimension.
// One thread per inner channel: the d_conv filter taps and the last d_conv inputs are kept
// in registers, and the input window is rotated as the thread walks over the n_t tokens.
template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
                                    const int64_t n_t) {
    GGML_UNUSED(src0_nb0);

    const int tid  = threadIdx.x;
    const int bidx = blockIdx.x;  // sequence index
    const int bidy = blockIdx.y;  // group of split_d_inner channels

    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
    float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);

    const int stride_x = src0_nb1 / sizeof(float);
    const int stride_w = src1_nb1 / sizeof(float);
    const int stride_y = dst_nb1 / sizeof(float);

    float x[d_conv] = { 0.0f };
    float w[d_conv] = { 0.0f };

    // load the filter taps for this channel once
#pragma unroll
    for (size_t j = 0; j < d_conv; j++) {
        w[j] = w_block[tid * stride_w + j];
    }

    for (int64_t i = 0; i < n_t; i++) {
        float sumf = 0.0f;

        if (i == 0) {
            // initial window: the first d_conv inputs for this channel
            for (size_t j = 0; j < d_conv; j++) {
                x[j] = x_block[tid * stride_x + j];
            }
        } else {
            // slide the window by one token: overwrite the oldest entry of the circular buffer
            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
        }

#pragma unroll
        for (size_t j = 0; j < d_conv; j++) {
            sumf += x[(i + j) % d_conv] * w[j];
        }
        y_block[i * stride_y + tid] = sumf;
    }
}
// Variant for long token sequences: the token dimension is additionally split across
// blockIdx.z, with each block handling a chunk of split_n_t tokens.
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {
    const int tid  = threadIdx.x;
    const int bidx = blockIdx.x;  // sequence index
    const int bidy = blockIdx.y;  // group of split_d_inner channels
    const int bidz = blockIdx.z;  // chunk of split_n_t tokens

    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
                                             bidz * split_n_t * src0_nb0);
    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
    float *       y_block =
        (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);

    const int stride_x = src0_nb1 / sizeof(float);
    const int stride_w = src1_nb1 / sizeof(float);
    const int stride_y = dst_nb1 / sizeof(float);

    float x[d_conv] = { 0.0f };
    float w[d_conv] = { 0.0f };

    // load the filter taps for this channel once
#pragma unroll
    for (size_t j = 0; j < d_conv; j++) {
        w[j] = w_block[tid * stride_w + j];
    }

#pragma unroll
    for (int64_t i = 0; i < split_n_t; i++) {
        // the last chunk may extend past the end of the sequence
        if (bidz * split_n_t + i < n_t) {
            float sumf = 0.0f;

            if (i == 0) {
                for (size_t j = 0; j < d_conv; j++) {
                    x[j] = x_block[tid * stride_x + j];
                }
            } else {
                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
            }

#pragma unroll
            for (size_t j = 0; j < d_conv; j++) {
                sumf += x[(i + j) % d_conv] * w[j];
            }
            y_block[i * stride_y + tid] = sumf;
        }
    }
}
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
                              const int64_t n_s, cudaStream_t stream) {
    const int threads = 128;
    GGML_ASSERT(nr % threads == 0);

    // dispatch on the compile-time kernel size (d_conv)
    auto launch_kernel = [&](auto NC) {
        constexpr int kNC = decltype(NC)::value;
        if (n_t <= 32) {
            const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
            ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                       dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else {
            // long sequences: also split the token dimension across blockIdx.z
            const int64_t split_n_t = 32;
            dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
            ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        }
    };

    switch (nc) {
        case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
        case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
        case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
    }
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];  // conv_x
    const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight

    const int64_t nc  = src1->ne[0];  // d_conv
    const int64_t nr  = src0->ne[1];  // d_inner
    const int64_t n_t = dst->ne[1];   // tokens per sequence
    const int64_t n_s = dst->ne[2];   // number of sequences in the batch

    GGML_ASSERT(dst->ne[0] == nr);
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float *       dst_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
                      dst->nb[2], nc, nr, n_t, n_s, stream);
}