
HIP: enable mmf for RDNA3 (#17879)

* enable mmf for RDNA3

* disable mmf for some shape

* move some mmvf to mmf

* more mmvf to mmf

* ne11 <= 3 is good in mmvf

---------

Co-authored-by: zhang hui <you@example.com>
yulo committed 1 month ago · commit c33a58bced
4 changed files with 83 additions and 20 deletions:
  1. ggml/src/ggml-cuda/common.cuh (+14, -11)
  2. ggml/src/ggml-cuda/mma.cuh (+60, -5)
  3. ggml/src/ggml-cuda/mmf.cu (+5, -3)
  4. ggml/src/ggml-cuda/mmvf.cu (+4, -1)

+ 14 - 11
ggml/src/ggml-cuda/common.cuh

@@ -67,19 +67,22 @@
 #define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
 #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA3_5    (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops.
 #define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
 
-#define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
-#define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
-#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
-#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
-#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
-#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
-#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
-#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
-#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
-#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_AMD(cc)     (cc >= GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_RDNA(cc)    (cc >= GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_RDNA1(cc)   (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
+#define GGML_CUDA_CC_IS_RDNA2(cc)   (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5)
+#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_RDNA3(cc)   (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))
+#define GGML_CUDA_CC_IS_RDNA4(cc)   (cc >= GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_GCN(cc)     (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
+#define GGML_CUDA_CC_IS_CDNA(cc)    (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA1(cc)   (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
+#define GGML_CUDA_CC_IS_CDNA2(cc)   (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
+#define GGML_CUDA_CC_IS_CDNA3(cc)   (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
 #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons
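
Note on the new macros: RDNA3 is split into an RDNA3.0 range and an RDNA3.5 range so that later checks (see mmf.cu below) can restrict plain RDNA3.0 separately, while GGML_CUDA_CC_IS_RDNA3 still matches both. Only the relative ordering of the boundaries matters. A minimal host-side sketch, not part of the commit: GGML_CUDA_CC_OFFSET_AMD is given an assumed placeholder value, and gfx1100/gfx1151 are illustrative RDNA3.0 and RDNA3.5 parts.

    // Standalone sketch: classify two example AMD cc values with the new macros.
    #include <cstdio>

    #define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // assumed placeholder offset
    #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100)
    #define GGML_CUDA_CC_RDNA3_5    (GGML_CUDA_CC_OFFSET_AMD + 0x1150)
    #define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200)

    #define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3   && cc < GGML_CUDA_CC_RDNA3_5)
    #define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)
    #define GGML_CUDA_CC_IS_RDNA3(cc)   (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))

    int main() {
        const int cc_gfx1100 = GGML_CUDA_CC_OFFSET_AMD + 0x1100; // e.g. RX 7900 class
        const int cc_gfx1151 = GGML_CUDA_CC_OFFSET_AMD + 0x1151; // e.g. AI Max 395 class
        // gfx1100 -> RDNA3_0 only; gfx1151 -> RDNA3_5 only; both count as RDNA3.
        printf("gfx1100: 3_0=%d 3_5=%d 3=%d\n", GGML_CUDA_CC_IS_RDNA3_0(cc_gfx1100),
               GGML_CUDA_CC_IS_RDNA3_5(cc_gfx1100), GGML_CUDA_CC_IS_RDNA3(cc_gfx1100));
        printf("gfx1151: 3_0=%d 3_5=%d 3=%d\n", GGML_CUDA_CC_IS_RDNA3_0(cc_gfx1151),
               GGML_CUDA_CC_IS_RDNA3_5(cc_gfx1151), GGML_CUDA_CC_IS_RDNA3(cc_gfx1151));
        return 0;
    }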

+ 60 - 5
ggml/src/ggml-cuda/mma.cuh

@@ -189,6 +189,9 @@ namespace ggml_cuda_mma {
                 return 8 * (threadIdx.x / 16) + l;
 #elif defined(RDNA3)
                 return 2 * l + (threadIdx.x / 16);
+#else
+                NO_DEVICE_CODE;
+                return -1;
 #endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
@@ -290,8 +293,12 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-
+#if defined(RDNA3)
+        // RDNA3 has duplicated data as input.
+        static constexpr int ne = I * J / 32 * 2;
+#else
         static constexpr int ne = I * J / 32;
+#endif // defined(RDNA3)
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
@@ -310,7 +317,14 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
+#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
+#elif defined(RDNA3)
+                return l;
+#else
+                NO_DEVICE_CODE;
+                return -1;
+#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -366,11 +380,16 @@ namespace ggml_cuda_mma {
         static constexpr int         I  = I_;
         static constexpr int         J  = J_;
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
-        static constexpr int         ne = I * J / WARP_SIZE;
 
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA3)
+        // RDNA3 has duplicated data as input.
+        static constexpr int ne = I * J / 32 * 2;
+#else
+        static constexpr int ne = I * J / 32;
+#endif // defined(RDNA3)
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
-#if defined(AMD_WMMA_AVAILABLE)
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 8) return true;
             return false;
@@ -387,13 +406,23 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
+#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
+#elif defined(RDNA3)
+                return l;
+#else
+                NO_DEVICE_CODE;
+                return -1;
+#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
             }
         }
 #else
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
         static constexpr __device__ bool supported() {
             if (I ==  8 && J ==  8) return true;
             if (I == 16 && J ==  4) return true;
@@ -546,8 +575,14 @@ namespace ggml_cuda_mma {
         }
 #elif defined(AMD_WMMA_AVAILABLE)
         if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-
+#if defined(RDNA4)
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+#elif defined(RDNA3)
+                ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+                ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
+#else
+                NO_DEVICE_CODE;
+#endif // defined(RDNA4)
         } else if constexpr (std::is_same_v<T, int>) {
             if constexpr (I == 16 && J == 4) {
                 int64_t * xi = (int64_t *) t.x;
@@ -888,6 +923,16 @@ namespace ggml_cuda_mma {
         const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
         const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
+        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
 #endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
@@ -905,6 +950,16 @@ namespace ggml_cuda_mma {
         const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
         const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
+        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
 #endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
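
The mma.cuh changes all follow from one hardware difference: in wave32 mode, RDNA3's WMMA instructions take 16 half/bf16 elements per lane per operand (8 payload values duplicated across the two halves of the wave), while RDNA4/gfx12 takes 8. Hence `ne` is doubled on RDNA3 and the tile load is issued as two half-sized copies; since `get_j` returns plain `l` (independent of `threadIdx.x / 16`), the upper half of the wave loads the same elements as the lower half, producing the duplicated layout. A compile-only HIP sketch of the f16 path, assuming a gfx11 target (the macro guard lists an illustrative subset of gfx11 targets):

    #include <hip/hip_runtime.h>

    using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
    using floatx8_t = __attribute__((ext_vector_type(8))) float;

    // One lane's slice of a 16x16x16 f32 += f16*f16 WMMA on RDNA3 (wave32):
    // each operand carries 16 halves per lane (8 values duplicated across the
    // two wave halves); the accumulator is 8 floats per lane.
    __device__ floatx8_t wmma_f16_rdna3(const halfx16_t a_frag,
                                        const halfx16_t b_frag,
                                        floatx8_t       acc_frag) {
    #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
    #endif
        return acc_frag;
    }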

+ 5 - 3
ggml/src/ggml-cuda/mmf.cu

@@ -151,7 +151,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     } else {
-        if (src1_ncols > 16) {
+        if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
+            return false;
+        } else if (src1_ncols > 16) {
             return false;
         }
     }
@@ -160,9 +162,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
+            return ampere_mma_available(cc) || amd_wmma_available(cc);
         default:
             return false;
     }
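
Net effect in mmf.cu: F16 and BF16 mmf now runs on any GPU with AMD WMMA (RDNA3 and RDNA4) rather than RDNA4 only, but plain RDNA3.0 keeps a tighter batch limit of 8 src1 columns. A simplified standalone sketch of just the column gate; the real function also checks the tensor type and MMA/WMMA availability, and `is_rdna3_0` stands in for GGML_CUDA_CC_IS_RDNA3_0(cc):

    #include <cassert>
    #include <cstdint>

    // Simplified column gate from ggml_cuda_should_use_mmf (sketch only).
    static bool mmf_ncols_ok(bool is_rdna3_0, int64_t src1_ncols) {
        if (is_rdna3_0 && src1_ncols > 8) {
            return false; // new, stricter cutoff for plain RDNA3.0
        }
        return src1_ncols <= 16; // pre-existing general limit
    }

    int main() {
        assert(!mmf_ncols_ok(true,  12)); // RDNA3.0: falls back to another path
        assert( mmf_ncols_ok(false, 12)); // RDNA3.5/RDNA4: still uses mmf
        assert( mmf_ncols_ok(true,   8)); // small batches stay on mmf everywhere
        return 0;
    }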

+ 4 - 1
ggml/src/ggml-cuda/mmvf.cu

@@ -765,7 +765,10 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0
                 return ne11 <= 8;
             } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                 if (fp16_mma_hardware_available(cc)) {
-                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+                        return ne11 <= 3;
+                    }
+                    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
                         return ne11 <= 5;
                     }
                     return ne11 <= 2;
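
The mmvf.cu change splits the former shared RDNA3/RDNA4 threshold (ne11 <= 5) so that RDNA3 hands batches over to mmf earlier, matching the "ne11 <= 3 is good in mmvf" measurement in the commit message. A simplified sketch of the per-arch cutoff; `ne11` is the src1 column count, and the helper name is made up for illustration:

    // Sketch: max src1 column count for which mmvf is preferred on AMD GPUs
    // with fp16 MMA hardware (simplified from ggml_cuda_should_use_mmvf).
    static int mmvf_max_ne11_amd(bool is_rdna3, bool is_rdna4) {
        if (is_rdna3) return 3; // was 5 before this commit; larger batches -> mmf
        if (is_rdna4) return 5;
        return 2;               // other AMD fp16 MMA hardware
    }

Combined with the mmf.cu gate above, RDNA3 now routes ne11 of 4 and up to mmf (up to 8 columns on RDNA3.0, 16 on RDNA3.5), which is the "move some mmvf to mmf" step of the commit message.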