@@ -12,7 +12,8 @@
 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; passing unaligned pointers is undefined behavior.

 #include "common.cuh"
@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
     struct tile {
         static constexpr int I = I_;
         static constexpr int J = J_;
-        static constexpr int ne = I * J / WARP_SIZE;
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
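+        // A 64-thread wavefront is assumed here, hence each thread owns I * J / 64 of the tile's 32 bit elements.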
+        static constexpr int ne = I * J / 64;
+        T x[ne] = {0};
+
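+        // get_i/get_j map the l-th element owned by a thread back to its row/column
+        // within the tile; the cases below mirror the MFMA operand and accumulator layouts.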
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 4) {
+                return threadIdx.x % 32;
+            } else if constexpr (I == 16 && J == 16) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 32) {
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return 2 * ((threadIdx.x / 16) % 2) + l;
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 4) {
+                return 2 * (threadIdx.x / 32) + l;
+            } else if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 32) {
+                return threadIdx.x % 32;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#else
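+        // NVIDIA warps have 32 threads, so this matches the previous I * J / WARP_SIZE.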
+        static constexpr int ne = I * J / 32;
         T x[ne] = {0};

         static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     };

     template <int I_, int J_>
@@ -148,10 +187,23 @@ namespace ggml_cuda_mma {

     template <int I, int J, typename T>
     static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        } else {
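+            // Vectorized path: each thread copies its two 32 bit elements with a single
+            // 64-bit load, relying on the 16-byte alignment required by load_generic.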
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        }
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
+#endif // defined(AMD_MFMA_AVAILABLE)
     }

     template <typename T>
@@ -186,7 +238,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(NEW_MMA_AVAILABLE)
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -393,4 +445,60 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
+
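+    // int8 MFMA path: each 32 bit element of the <16, 8> int tiles packs four int8
+    // values, so A and B are logically 16x32 int8 operands accumulated into the
+    // 16x16 int32 tile D.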
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
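+        // CDNA3 has a single 16x16x32 int8 MFMA; the A/B fragments are passed as packed 64-bit values.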
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((const int64_t *) A.x)[0],
+                                                       ((const int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
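+        // CDNA and CDNA2 only provide a 16x16x16 int8 MFMA, so K = 32 takes two chained calls.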
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+                                                      B.x[0],
+                                                      acc[0],
+                                                      0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+                                                      B.x[1],
+                                                      acc[0],
+                                                      0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_MFMA_AVAILABLE)
+    }
+
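+    // Same scheme for 32x32 output tiles: the <32, 4> int tiles are logically 32x16 int8 operands.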
+    static __device__ __forceinline__ void mma(
+            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
+        int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((const int64_t *) A.x)[0],
+                                                       ((const int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
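+        // CDNA and CDNA2: two 32x32x8 int8 MFMA steps cover K = 16.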
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+                                                     B.x[0],
+                                                     acc[0],
+                                                     0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+                                                     B.x[1],
+                                                     acc[0],
+                                                     0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_MFMA_AVAILABLE)
+    }
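+
+    // Illustrative usage of the MFMA path (variable names are hypothetical):
+    //   tile<16, 16, int> D;               // int32 accumulator
+    //   tile<16,  8, int> A, B;            // packed int8 operands
+    //   load_generic(A, shmem_a, stride_a);
+    //   load_generic(B, shmem_b, stride_b);
+    //   mma(D, A, B);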
}