Просмотр исходного кода

HIP: RDNA4 tensor core support for MMF (#17077)

* mmf for rdna4

* align the padding for rdna4

* forbit mul_mat_f for rdna4

* fix as comment

* remove device kernels

* add constexpr for early return

* update based on review comment

* change based on the review comment

* pass compile error

* keep code consistency

---------

Co-authored-by: zhang hui <you@example.com>
yulo 1 месяц назад
Родитель
Сommit
028f93ef98

+ 8 - 0
ggml/src/ggml-cuda/common.cuh

@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
+#if defined(GGML_USE_HIP) && defined(RDNA4)
+#define AMD_WMMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(RDNA4)
+
 // The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #define VOLTA_MMA_AVAILABLE
@@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }
 
+static bool amd_wmma_available(const int cc) {
+    return GGML_CUDA_CC_IS_RDNA4(cc);
+}
+
 static bool volta_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
 }

+ 9 - 0
ggml/src/ggml-cuda/convert.cuh

@@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
         return __float2bfloat16(float(x));
     } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
         return __bfloat162float(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
+        return __float22half2_rn(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
+        // bypass compile error on cuda 12.0.1
+#ifdef GGML_USE_HIP
+        return __float22bfloat162_rn(x);
+#else
+        return {x.x, x.y};
+#endif // GGML_USE_HIP
     } else if constexpr(std::is_same_v<dst_t, int32_t>) {
         return int32_t(x);
     } else {

+ 107 - 0
ggml/src/ggml-cuda/mma.cuh

@@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
         static constexpr int J  = J_;
 
 #if defined(GGML_USE_HIP)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
@@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#endif // defined(RDNA4)
 #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
@@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, nv_bfloat162> {
         static constexpr int I  = I_;
         static constexpr int J  = J_;
+
+#if defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
@@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#endif  // defined(AMD_WMMA_AVAILABLE)
     };
 
     template <int I, int J>
@@ -353,6 +436,8 @@ namespace ggml_cuda_mma {
             const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
             xi[0] = xs[0];
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
@@ -639,12 +724,34 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
+        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
 #if defined(AMD_MFMA_AVAILABLE)

+ 3 - 3
ggml/src/ggml-cuda/mmf.cu

@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     } else {
-        if (src1_ncols > 16) {
+        if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
             return false;
         }
     }
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc);
         default:
             return false;
     }

+ 53 - 20
ggml/src/ggml-cuda/mmf.cuh

@@ -2,6 +2,7 @@
 
 #include "mma.cuh"
 #include "common.cuh"
+#include "convert.cuh"
 
 using namespace ggml_cuda_mma;
 
@@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16,       8, T>     tile_A;
+    typedef tile<tile_B_I, 8, T>     tile_B;
+    typedef tile<16,       tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
+    constexpr bool supported = I_16_supported || I_32_supported;
 
     constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
 
     typedef tile<I_preferred, 8, T>     tile_A;
     typedef tile<8,           8, T>     tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
 
                     if constexpr (!has_ids) {
                         const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
-                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                     } else {
                         const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                         float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
-                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                     }
                 }
             } else {
@@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 //This kernel is for larger batch sizes of mul_mat_id
@@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16,       8, T>     tile_A;
+    typedef tile<tile_B_I, 8, T>     tile_B;
+    typedef tile<16,       tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+    constexpr bool supported = I_16_supported || I_32_supported;
 
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
 
     typedef tile<I_preferred, 8, T>     tile_A;
     typedef tile<8,           8, T>     tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
 #pragma unroll
                 for (int j0 = 0; j0 < tile_B::I; ++j0) {
                     const float2 tmp = vals_buf[curr_buf][j0];
-                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                 }
 
                 if (itB + 1 < ntB) {
@@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template<typename T, int cols_per_block, int nwarps>
@@ -554,7 +585,8 @@ void mul_mat_f_cuda(
         cudaStream_t stream, const mmf_ids_data * ids_data) {
     typedef tile<16, 8, T>     tile_A_16;
     typedef tile<32, 8, T>     tile_A_32;
-    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, T>     tile_B_16;
+    typedef tile< 8, 8, T>     tile_B_8;
 
     GGML_ASSERT(ncols_x      % 2 == 0);
     GGML_ASSERT(stride_row   % 2 == 0);
@@ -581,7 +613,8 @@ void mul_mat_f_cuda(
 
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
     const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;