@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
             return threadIdx.x / 4;
         } else if constexpr (I == 16 && J == 8) {
             return (l / 2) * 8 + threadIdx.x / 4;
+        } else if constexpr (I == 16 && J == 16) {
+            return ((l / 2) % 2) * 8 + threadIdx.x / 4;
         } else {
             static_assert(I == -1 && J == -1, "template specialization not implemented");
         }
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
             return 4 * l + threadIdx.x % 4;
         } else if constexpr (I == 16 && J == 8) {
             return 2 * (threadIdx.x % 4) + l % 2;
+        } else if constexpr (I == 16 && J == 16) {
+            return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
         } else {
             static_assert(I == -1 && J == -1, "template specialization not implemented");
         }
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }

+    static __device__ __forceinline__ void mma(
+            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
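+        // Two m16n8k16 mmas: each one accumulates into one half of D's registers, reusing all of A together with one half of B's registers.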
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
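+        // (each half of D's registers accumulates two k=8 steps, one per pair of A registers)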
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef NEW_MMA_AVAILABLE
@@ -316,4 +356,39 @@
 #endif // NEW_MMA_AVAILABLE
     }

+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
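+        // Same split as the f16 variant above: two m16n8k16 mmas, one per half of the f32 accumulator registers, each using one half of B's registers.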
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
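+        // (as above: each half of the accumulator registers gets two k=8 steps, one per pair of A registers)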
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
 }
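
For reference, a hypothetical usage sketch (not part of the patch) of how the new f32-accumulator overload would be called from device code; the wrapper name is illustrative, operand loading is elided, and the accumulator is assumed to be zeroed before the first call:

    // Illustrative only: one warp accumulates a 16x16 f32 tile from 16x16 half
    // operands, dispatching to the tile<16, 16, float> overload added above.
    // B is supplied as the col-major ("transposed") operand of the row.col mma,
    // so this is effectively D += A * B^T.
    static __device__ __forceinline__ void f16_mma_16x16_acc(
            ggml_cuda_mma::tile<16, 16, float>      & D,
            const ggml_cuda_mma::tile<16, 8, half2> & A,
            const ggml_cuda_mma::tile<16, 8, half2> & B) {
        ggml_cuda_mma::mma(D, A, B);
    }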