// mma.cuh

#include "common.cuh"
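
// Tile/fragment abstractions for int8 tensor core matrix multiplication via PTX mma.sync.
// Each struct describes how one warp holds a tile of 32-bit elements:
//   - I/J are the row/column dimensions of the tile; K is the reduction dimension counted
//     in 32-bit elements (each packing four int8 values for the A/B operands; the C
//     accumulator holds plain int32 values).
//   - ne is the number of 32-bit registers per thread (32 threads * ne == tile size).
//   - get_i/get_j/get_k map a thread's register index l to that element's tile coordinates.
//   - load() fills the fragment from row-major data with the given stride (in ints),
//     using ldmatrix where enabled and a per-element fallback loop otherwise.

// A operand tile used by mma_int_C_I16J8::mma_K4: 16 rows x 4 ints (16 int8 values per row).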
struct mma_int_A_I16K4 {
    static constexpr int I  = 16;
    static constexpr int K  = 4;
    static constexpr int ne = 2;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l%2) * (I/2) + threadIdx.x / K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  I);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int /* l */) {
        const int ret = threadIdx.x % K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
            : "+r"(x[0]), "+r"(x[1])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_i(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
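
// A operand tile used by mma_int_C_I16J8::mma_K8: 16 rows x 8 ints (32 int8 values per row).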
struct mma_int_A_I16K8 {
    static constexpr int I  = 16;
    static constexpr int K  = 8;
    static constexpr int ne = 4;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  I);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int l) {
        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_i(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
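
// B operand tile used by mma_int_C_I16J8::mma_K4: 8 x 4 ints (16 int8 values along K);
// a single register per thread.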
struct mma_int_B_J8K4 {
    static constexpr int J  = 8;
    static constexpr int K  = 4;
    static constexpr int ne = 1;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_j(const int /* l */) {
        const int ret = threadIdx.x / K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  J);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int /* l */) {
        const int ret = threadIdx.x % K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
        const int * xs = xs0 + (threadIdx.x%J)*stride;
        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
            : "+r"(x[0])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_j(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
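
// B operand tile used by mma_int_C_I16J8::mma_K8: 8 x 8 ints (32 int8 values along K).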
struct mma_int_B_J8K8 {
    static constexpr int J  = 8;
    static constexpr int K  = 8;
    static constexpr int ne = 2;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_j(const int /* l */) {
        const int ret = threadIdx.x / (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  J);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int l) {
        const int ret = l * (K/2) + threadIdx.x % (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
            : "+r"(x[0]), "+r"(x[1])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_j(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
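
// C accumulator tile: 16 x 8 int32 values, four per thread (zero-initialized).
// mma_K4/mma_K8 accumulate the int8 product A*B into x[] in place.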
struct mma_int_C_I16J8 {
    static constexpr int I  = 16;
    static constexpr int J  = 8;
    static constexpr int ne = 4;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  I);
        return ret;
    }

    static __device__ __forceinline__ int get_j(const int l) {
        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  J);
        return ret;
    }

    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
        GGML_UNUSED(mma_A);
        GGML_UNUSED(mma_B);
        NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
    }

    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
#else
        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
        GGML_UNUSED(mma_A);
        GGML_UNUSED(mma_B);
        NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
    }
};
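
// Minimal usage sketch (not part of the original header): one warp computes a 16x8 int32
// tile C[i][j] = sum_k A[i][k] * B[j][k] from int8 data packed four per 32-bit int.
// The kernel name, the exact buffer shapes, and the single-warp launch (e.g. <<<1, 32>>>)
// are assumptions made for illustration; the mma path requires INT8_MMA_AVAILABLE
// (Turing or newer).
static __global__ void example_mma_tile(
        const int * __restrict__ A,   // 16 x 8 ints, row-major (16 x 32 int8 values)
        const int * __restrict__ B,   //  8 x 8 ints, row-major ( 8 x 32 int8 values)
        int       * __restrict__ C) { // 16 x 8 int32 values, row-major
    mma_int_A_I16K8 mma_A;
    mma_int_B_J8K8  mma_B;
    mma_int_C_I16J8 mma_C;

    mma_A.load(A, mma_int_A_I16K8::K); // stride = ints per row
    mma_B.load(B, mma_int_B_J8K8::K);
    mma_C.mma_K8(mma_A, mma_B);        // accumulate into mma_C.x (starts at zero)

    // Each thread writes back the four accumulator values it owns.
#pragma unroll
    for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
        C[mma_C.get_i(l)*mma_int_C_I16J8::J + mma_C.get_j(l)] = mma_C.x[l];
    }
}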