mma.cuh

#include "common.cuh"

struct mma_int_A_I16K4 {
    static constexpr int I  = 16;
    static constexpr int K  = 4;
    static constexpr int ne = 2;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l%2) * (I/2) + threadIdx.x / K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  I);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int /* l */) {
        const int ret = threadIdx.x % K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }
};
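// Worked example of the mma_int_A_I16K4 index mapping above: a warp of 32 threads
// holds ne = 2 ints per thread, i.e. 64 ints = one 16x4 tile of ints (a 16x16 tile
// of int8 values, assuming each int packs 4 consecutive int8 values along k as in
// the PTX mma fragment layout):
//   threadIdx.x = 0: l=0 -> (i,k) = (0,0), l=1 -> (i,k) = (8,0)
//   threadIdx.x = 5: l=0 -> (i,k) = (1,1), l=1 -> (i,k) = (9,1)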
struct mma_int_A_I16K8 {
    static constexpr int I  = 16;
    static constexpr int K  = 8;
    static constexpr int ne = 4;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  I);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int l) {
        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }
};
struct mma_int_B_J8K4 {
    static constexpr int J  = 8;
    static constexpr int K  = 4;
    static constexpr int ne = 1;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_j(const int /* l */) {
        const int ret = threadIdx.x / K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  J);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int /* l */) {
        const int ret = threadIdx.x % K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }
};
struct mma_int_B_J8K8 {
    static constexpr int J  = 8;
    static constexpr int K  = 8;
    static constexpr int ne = 2;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_j(const int /* l */) {
        const int ret = threadIdx.x / (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  J);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int l) {
        const int ret = l * (K/2) + threadIdx.x % (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  K);
        return ret;
    }
};
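// Worked example of the mma_int_B_J8K8 index mapping above: ne = 2 ints per thread,
// 64 ints = one 8x8 tile of ints indexed as (j,k); B is the column-major ("col")
// operand of the row.col mma instructions further below:
//   threadIdx.x = 0: l=0 -> (j,k) = (0,0), l=1 -> (j,k) = (0,4)
//   threadIdx.x = 6: l=0 -> (j,k) = (1,2), l=1 -> (j,k) = (1,6)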
struct mma_int_C_I16J8 {
    static constexpr int I  = 16;
    static constexpr int J  = 8;
    static constexpr int ne = 4;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  I);
        return ret;
    }

    static __device__ __forceinline__ int get_j(const int l) {
        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret <  J);
        return ret;
    }
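    // Worked example of the mma_int_C_I16J8 accumulator mapping above (ne = 4 int32
    // values per thread, 128 ints = one 16x8 tile):
    //   threadIdx.x = 0: l=0..3 -> (i,j) = (0,0), (0,1), (8,0), (8,1)
    //   threadIdx.x = 7: l=0..3 -> (i,j) = (1,6), (1,7), (9,6), (9,7)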
    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
        GGML_UNUSED(mma_A);
        GGML_UNUSED(mma_B);
        NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
    }

    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
#else
        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
        GGML_UNUSED(mma_A);
        GGML_UNUSED(mma_B);
        NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
    }
};
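// Illustrative usage sketch, not a definitive implementation: a single warp computes
// one 16x8 int32 tile C = A * B^T from int8 operands that the caller has already
// packed 4 int8 values per int along k and stored row-major. The kernel name, its
// parameters, and this packing convention are assumptions made for the example only.
// It must be launched with exactly 32 threads, because the index helpers above use
// threadIdx.x directly as the lane index.
static __global__ void example_mma_tile_int8(
        const int * __restrict__ A_packed,  // 16 x 8 ints = 16 x 32 int8, row-major along k
        const int * __restrict__ B_packed,  //  8 x 8 ints =  8 x 32 int8, row-major along k
        int       * __restrict__ C) {       // 16 x 8 int32, row-major
    mma_int_A_I16K8 mma_A;
    mma_int_B_J8K8  mma_B;
    mma_int_C_I16J8 mma_C; // accumulators start at 0 via the x[ne] = {0} initializer

    // Each lane loads the fragment elements it owns, using the index helpers to
    // translate (lane, l) into (row, packed k) coordinates.
#pragma unroll
    for (int l = 0; l < mma_int_A_I16K8::ne; ++l) {
        mma_A.x[l] = A_packed[mma_A.get_i(l)*mma_int_A_I16K8::K + mma_A.get_k(l)];
    }
#pragma unroll
    for (int l = 0; l < mma_int_B_J8K8::ne; ++l) {
        mma_B.x[l] = B_packed[mma_B.get_j(l)*mma_int_B_J8K8::K + mma_B.get_k(l)];
    }

    // One m16n8k32 int8 mma (or the 4x m8n8k16 Turing fallback) accumulates into mma_C.x.
    mma_C.mma_K8(mma_A, mma_B);

    // Scatter the per-lane results back to row-major global memory.
#pragma unroll
    for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
        C[mma_C.get_i(l)*mma_int_C_I16J8::J + mma_C.get_j(l)] = mma_C.x[l];
    }
}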