Browse Source

HIP: Refactor mma for RDNA and CDNA (#17990)

* mma.cuh for rdna4

* mma for rdna3

* mmq for rdna4

* mmq for rdna3

* align i-major and j-major

* cdna

* fix cuda error

* add missing tile of mfma

* fix j-major wrong ne on CDNA

* fix gramma and empty spaces

---------

Co-authored-by: zhang hui <you@example.com>
yulo 1 month ago
parent
commit
acec774ef6
3 changed files with 179 additions and 150 deletions
  1. 129 111
      ggml/src/ggml-cuda/mma.cuh
  2. 14 10
      ggml/src/ggml-cuda/mmf.cuh
  3. 36 29
      ggml/src/ggml-cuda/mmq.cuh

+ 129 - 111
ggml/src/ggml-cuda/mma.cuh

@@ -76,15 +76,31 @@ namespace ggml_cuda_mma {
         // For the A/C matrices this means I major == row major, J major == column major.
         // For the B matrix this means I major == column major, J major == row major.
         // MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED  = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED  = 20,
+        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR           = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED  = 20,
+        DATA_LAYOUT_J_MAJOR_MIRRORED  = 30,
+        DATA_LAYOUT_I_MAJOR_DUAL      = 40, // Matrix A&B for RDNA3.
     };
     // Implemented mma combinations are:
     //   - (I_MAJOR, I_MAJOR)          -> I_MAJOR
     //   - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
     //   - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
 
+    constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR ||
+               dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
+               dl == DATA_LAYOUT_I_MAJOR_DUAL;
+    }
+
+    constexpr data_layout get_input_data_layout() {
+#if defined(RDNA3)
+        return DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+        return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3)
+    }
+
     template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
     struct tile {};
 
@@ -115,9 +131,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return threadIdx.x % 32;
             } else if constexpr (I == 16 && J == 16) {
-                return 4 * (threadIdx.x / 16) + l;
+                return threadIdx.x % 16;
             } else if constexpr (I == 32 && J == 32) {
-                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+                return threadIdx.x % 32;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -132,9 +148,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return 2 * (threadIdx.x / 32) + l;
             } else if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                return 4 * (threadIdx.x / 16) + l;
             } else if constexpr (I == 32 && J == 32) {
-                return threadIdx.x % 32;
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -171,28 +187,19 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
         static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
-        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -201,7 +208,17 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                // matrix C
+#if defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -293,12 +310,7 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
@@ -317,14 +329,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -382,42 +387,19 @@ namespace ggml_cuda_mma {
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I == 16 && J == 8) return true;
-            return false;
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
-                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
         }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
@@ -458,6 +440,28 @@ namespace ggml_cuda_mma {
 #endif  // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
+        }
+    };
+
     template <int I_, int J_>
     struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
         static constexpr int         I  = I_;
@@ -524,6 +528,42 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+
+        static constexpr int         ne = I * J / 32 * 2;
+
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8)  return true;
+            if (I == 16 && J == 4)  return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (supported()) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
 #if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -569,55 +609,28 @@ namespace ggml_cuda_mma {
                 t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
             }
         } else {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#if defined(RDNA4)
-                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
-                ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-                ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
-                NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-        } else if constexpr (std::is_same_v<T, int>) {
-            if constexpr (I == 16 && J == 4) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-#endif // defined(RDNA4)
-            } else if constexpr (I == 16 && J == 8) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-
-                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-                xi[1] = xs1[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-                const int64_t * xs1 = xs + 2;
-                xi[2] = xs1[0];
-                xi[3] = xs1[1];
-#endif // defined(RDNA4)
+        // All wmma layout has contiguous data when i-major.
+        if constexpr (is_i_major(dl)) {
+            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
             } else {
-                NO_DEVICE_CODE;
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
             }
         } else {
-            NO_DEVICE_CODE;
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
         }
 #else
 #pragma unroll
@@ -660,9 +673,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +845,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
 #ifdef AMPERE_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -887,8 +901,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
 #ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -940,8 +955,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -967,8 +983,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
         int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1139,9 @@ namespace ggml_cuda_mma {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
 
-static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;

+ 14 - 10
ggml/src/ggml-cuda/mmf.cuh

@@ -32,11 +32,13 @@ static __global__ void mul_mat_f(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16,       8, T>     tile_A;
-    typedef tile<tile_B_I, 8, T>     tile_B;
-    typedef tile<16,       tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+    typedef tile<16,       8,        T,     ab_layout>           tile_A;
+    typedef tile<tile_B_I, 8,        T,     ab_layout>           tile_B;
+    typedef tile<16,       tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
@@ -272,11 +274,13 @@ static __global__ void mul_mat_f_ids(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16,       8, T>     tile_A;
-    typedef tile<tile_B_I, 8, T>     tile_B;
-    typedef tile<16,       tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+    typedef tile<16,       8,        T,     ab_layout>           tile_A;
+    typedef tile<tile_B_I, 8,        T,     ab_layout>           tile_B;
+    typedef tile<16,       tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {

+ 36 - 29
ggml/src/ggml-cuda/mmq.cuh

@@ -797,9 +797,10 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -966,9 +967,10 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1130,10 +1132,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64,  2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64,  2, int, input_layout>        tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1179,9 +1182,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16,  4, int> tile_A;
-    typedef tile<16,  4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  4, int, input_layout>        tile_A;
+    typedef tile<16,  4, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1435,10 +1439,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64,  2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64,  2, int, input_layout>        tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1501,10 +1506,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-
-    typedef tile<16,  4, int> tile_A;
-    typedef tile<16,  4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  4, int, input_layout>        tile_A;
+    typedef tile<16,  4, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2265,10 +2270,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64,  2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64,  2, int, input_layout>        tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2316,9 +2322,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16,  4, int> tile_A;
-    typedef tile<16,  4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  4, int, input_layout>        tile_A;
+    typedef tile<16,  4, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -3015,7 +3022,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int tileC_IJ = mmq_get_granularity_device(0);
-    typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+    typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
     constexpr int rows_per_warp = granularity;
 #else
     typedef tile<16, 8, int> tile_C;