
HIP: WMMA-MMQ kernels for RDNA 4 (#17156)

* first commit naive test to enable mmq for RDNA4

* adding appropriate WMMA instructions

* git rebase on top of master: fixing the correctness of the matmul operations, updating layout mappings for RDNA4 (see the layout sketch after this list)

* clean up merge conflicts

* add comments and code clean up

* PR clean up, addressed comments

* enable MMQ fallback on RDNA4

* addressed comments: add guards in load generic, separate wmma branch for use_mmq function

* Revert build-xcframework.sh

* Formatting: remove trailing whitespace

* revert CMake files

* clean up after rebase: remove duplicated change, revert cmake files

* clean up after rebase: revert changes from build-xcframework.sh

* clean up: remove extra space line in mma.cuh

* Revert "clean up: remove extra space line in mma.cuh"

This reverts commit b39ed57c4529906466bd0bc7c2a86e08fc2f8bee.
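
Layout sketch. For readers unfamiliar with the RDNA4 mapping referenced above: a 16x16 int tile is spread over the 32 lanes of a wave, so each lane owns ne = 16*16/32 = 8 elements, and get_i()/get_j() in mma.cuh map (lane, element) to a unique row and column. The host-side check below is a hypothetical illustration of that mapping, not code from this commit.

// Hypothetical host-side illustration (not part of this commit) of the RDNA4
// 16x16 tile mapping added in mma.cuh: lane l owns column l % 16 and rows
// 8 * (l / 16) .. 8 * (l / 16) + 7, i.e. 8 elements per lane on a 32-wide wave.
#include <cassert>
#include <cstdio>

int main() {
    bool covered[16][16] = {};
    const int ne = 16 * 16 / 32;                 // elements per lane, as in tile<16, 16, T>
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < ne; ++l) {
            const int i = 8 * (lane / 16) + l;   // row, mirrors get_i(l)
            const int j = lane % 16;             // column, mirrors get_j(l)
            assert(!covered[i][j]);              // each element is owned by exactly one lane
            covered[i][j] = true;
        }
    }
    std::printf("16x16 tile covered exactly once by 32 lanes x %d elements\n", ne);
    return 0;
}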
Jiacheng (Jason) Chen 1 month ago
parent
commit
0543f928a3
3 changed files with 379 additions and 139 deletions
  1. ggml/src/ggml-cuda/mma.cuh   +101 -30
  2. ggml/src/ggml-cuda/mmq.cu    +7 -1
  3. ggml/src/ggml-cuda/mmq.cuh   +271 -108

+ 101 - 30
ggml/src/ggml-cuda/mma.cuh

@@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
        static constexpr int I  = I_;
        static constexpr int J  = J_;

-#if defined(GGML_USE_HIP)
-#if defined(RDNA4)
-        static constexpr int ne = I * J / 32;
-        T x[ne] = {0};
-
-        static constexpr __device__ bool supported() {
-            if (I == 16 && J == 16) return true;
-            return false;
-        }
-
-        static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return 8 * (threadIdx.x / 16) + l;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
-        }
-
-        static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
-        }
-#else
+#if defined(AMD_MFMA_AVAILABLE)
        static constexpr int ne = I * J / 64;
        T x[ne] = {0};

@@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
                return -1;
            }
        }
-#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};
@@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
                return -1;
            }
        }
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#endif
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};
@@ -437,7 +437,20 @@ namespace ggml_cuda_mma {
            xi[0] = xs[0];
        }
#elif defined(AMD_WMMA_AVAILABLE)
-        ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+        if constexpr (I == 16 && J == 4) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        }else if constexpr (I == 16 && J == 8) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+
+            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+            xi[1] = xs1[0];
+        }else{
+            NO_DEVICE_CODE;
+        }
#else
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
@@ -772,6 +785,36 @@ namespace ggml_cuda_mma {
                                                      acc[0],
                                                      0, 0, 0);
#endif // defined(CDNA3)
+
+#elif defined(AMD_WMMA_AVAILABLE)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif // defined(RDNA4)
+
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
@@ -798,6 +841,7 @@ namespace ggml_cuda_mma {
                                                     acc[0],
                                                     0, 0, 0);
#endif // defined(CDNA3)
+
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
@@ -842,4 +886,31 @@ namespace ggml_cuda_mma {
        mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    }
+
+static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+    using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+    int32x2_t * a_vec = (int32x2_t *) A.x;
+    int32x2_t * b_vec = (int32x2_t *) B.x;
+
+    using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+    int32x8_t * acc = (int32x8_t *) D.x;
+
+    acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+        true,
+        a_vec[0],
+        true,
+        b_vec[0],
+        acc[0],
+        false
+    );
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif
+    }
}
+
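
A note on the new AMD_WMMA_AVAILABLE branches above: on RDNA4 each wave multiplies 16x16 int8 fragments with __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12, and a K = 32 tile product is issued as two K = 16 calls that accumulate into the same registers. The helper below is a condensed, hedged sketch of that call pattern; the name wmma_iu8_k32 and the standalone wrapper are illustrative, not part of the commit.

// Hedged sketch of the gfx12 int8 WMMA pattern used in mma.cuh above.
// Per 32-lane wave: each K = 16 step consumes 8 packed int8 values of A and B
// per lane (two ints each) and updates 8 int32 accumulators per lane.
#include <hip/hip_runtime.h>

using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;

// Illustrative helper name; RDNA4 is the guard ggml uses for gfx12 targets.
static __device__ __forceinline__ int32x8_t wmma_iu8_k32(
        const int32x2_t a0, const int32x2_t a1,   // A fragment, first/second K = 16 half
        const int32x2_t b0, const int32x2_t b1,   // B fragment, first/second K = 16 half
        int32x8_t acc) {                          // running int32 accumulator
#if defined(RDNA4)
    // boolean flags match those used in the two-call mma() branch above
    acc = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a0, true, b0, acc, true);
    acc = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a1, true, b1, acc, true);
#endif // defined(RDNA4)
    return acc;
}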

+ 7 - 1
ggml/src/ggml-cuda/mmq.cu

@@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
        return false;
    }

-    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    if (amd_wmma_available(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+            return true;
+        }
+    }
+
+    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
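
In words, the change above makes ggml_cuda_should_use_mmq() always choose the MMQ path on RDNA4 when WMMA is available, while RDNA3 and CDNA keep the existing batch-size cutoff. Below is a minimal, hedged restatement of that decision with the architecture checks abstracted into plain booleans; the function and constant names are illustrative, not ggml API.

// Illustrative restatement of the decision added to mmq.cu above; the real
// code queries the compute capability via amd_wmma_available()/GGML_CUDA_CC_IS_*().
#include <cstdint>

constexpr int64_t kDp4aMaxBatch = 64;  // stand-in for MMQ_DP4A_MAX_BATCH_SIZE, value assumed

static bool should_use_mmq_sketch(bool is_rdna4, bool is_rdna3, bool is_cdna,
                                  bool wmma_available, int64_t ne11) {
    if (wmma_available && is_rdna4) {
        return true;   // RDNA4: the new WMMA-based MMQ kernels are always preferred
    }
    // RDNA3/CDNA: MMQ only for small batches; other GPUs: always MMQ
    return (!is_rdna3 && !is_cdna) || ne11 < kDp4aMaxBatch;
}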

+ 271 - 108
ggml/src/ggml-cuda/mmq.cuh

File diff suppressed because it is too large