Browse Source

HIP: add fattn-mma-f16 for RDNA4 (#18481)

* finish VQ mma

* flash_attn_ext_f16_iter

* KQ_rowsum

* correct exp

* fix scale error

* fix softmax scale

* fix softmax scale

* enable fattn on cpu side

* fix random error

* disable fattn-mma-f16 on rdna3

* fix wrong col for rdna

* use identity mat to transpose

* resolve conflicts

* basic tuning for DeepSeek-R1-Distill-Qwen-1.5B

* fix volta compile error

* align rdna4 policy for fattn

* adjust fattn policy

* adjust kernel selection logic

* update as the review comments

* keep fattn-wmma logic

* adjust kernel selection logic

---------

Co-authored-by: zhang hui <you@example.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
yulo 2 weeks ago
parent
commit
ea4a321f2a

+ 4 - 0
ggml/src/ggml-cuda/common.cuh

@@ -262,6 +262,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define FLASH_ATTN_AVAILABLE
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 
+#if defined(TURING_MMA_AVAILABLE)
+#define LDMATRIX_TRANS_AVAILABLE
+#endif // defined(TURING_MMA_AVAILABLE)
+
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);

+ 1 - 1
ggml/src/ggml-cuda/fattn-common.cuh

@@ -914,7 +914,7 @@ void launch_fattn(
 
         const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;

+ 173 - 30
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -98,6 +98,19 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
+
+    // TODO tune specifically for RDNA
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
 static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -105,6 +118,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
     if (turing_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
     }
+    if (amd_wmma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+    }
     GGML_ASSERT(volta_mma_available(cc));
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 }
@@ -116,6 +132,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
     return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
 #elif defined(VOLTA_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
 #else
     GGML_UNUSED_VARS(DKQ, DV, ncols);
     return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -186,6 +204,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
     return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
 }
 
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE)
+    return 1; // RDNA has a single column.
+#else
+    return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE)
+}
+
+static __host__ int get_cols_per_warp(const int cc) {
+    if (turing_mma_available(cc) || amd_wmma_available(cc)) {
+        return 16;
+    } else {
+        // Volta
+        return 32;
+    }
+}
+
 // ------------------------------------------------------------------------------------------------------------------
 
 static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -393,10 +428,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int jt,
         const int kb0,
         const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
     constexpr int  ncols           = ncols1 * ncols2;
     constexpr int  cols_per_warp   = T_B_KQ::I;
-    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int  cols_per_thread = get_cols_per_thread();
     constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
@@ -413,6 +448,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
     T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #else // Volta
     T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
@@ -461,8 +498,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     if constexpr (cols_per_warp == 8) {
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                     } else {
-                        // Wide version of KQ_C is column-major => swap A and B.
+                        // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+                        // RDNA matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+                        // swap A and B for CUDA.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
                     }
                 }
             }
@@ -479,8 +522,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-                    // Wide version of KQ_C is column-major => swap A and B.
+                    // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+                    // RDNA matrix C is column-major.
+                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+                    // swap A and B for CUDA.
                     mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
                 }
             }
         }
@@ -532,7 +581,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -552,8 +607,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
-                    KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
                 } else {
                     KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
                 }
@@ -584,8 +645,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
                     // Turing + Volta:
-                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -596,7 +662,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             // Values per KQ column are spread across 4 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 1;
-#else
+#elif defined(AMD_WMMA_AVAILABLE)
+            // Values per KQ column are spread across 2 threads:
+            constexpr int offset_first = 16;
+            constexpr int offset_last  = 16;
+#else // Volta
             // Values per KQ column are spread across 2 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 2;
@@ -612,10 +682,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                // Turing + Volta:
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
-                    KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
                 } else {
                     KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                 }
@@ -639,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
 #if defined(TURING_MMA_AVAILABLE)
         if constexpr (cols_per_warp == 8) {
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
             for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -660,6 +735,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(
+            KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
 #else // Volta
         const half2 KQ_max_scale_h2 = make_half2(
             KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -707,6 +792,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
     constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+    T_A_VKQ A_identity;
+    make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
 
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
@@ -727,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
         const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
 
-#if defined(TURING_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
         for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -737,12 +826,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
                 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#else
+                // TODO: Try to transpose tile_V when loading gmem to smem.
+                // Use mma to transpose T_A_VKQ for RDNA.
+                T_A_VKQ A_trans;
+                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                mma(A, A_trans, A_identity);
+#endif // defined(TURING_MMA_AVAILABLE)
                 if constexpr (T_B_KQ::I == 8) {
                     mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                 } else {
-                    // Wide version of VKQ_C is column-major => swap A and B.
+                    // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE)
+                    // RDNA matrix C is column-major.
+                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+                    // swap A and B for CUDA.
                     mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE)
                 }
             }
         }
@@ -761,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
             }
         }
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
         if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
@@ -774,7 +877,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
 }
 
 #if defined(TURING_MMA_AVAILABLE)
@@ -794,6 +897,15 @@ template<> struct mma_tile_sizes<8> {
     using T_B_VKQ = tile< 8,  8, half2>; // column-major
     using T_C_VKQ = tile<16,  4, half2>; // row-major
 };
+#elif defined(AMD_WMMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ  = tile<16,  8, half2>; // row-major
+    using T_B_KQ  = tile<16,  8, half2>; // column-major
+    using T_C_KQ  = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16,  8, half2>; // row-major
+    using T_B_VKQ = tile<16,  8, half2>; // column-major
+    using T_C_VKQ = tile<16,  8, half2>; // column-major
+};
 #else // Volta
 template<int ncols> struct mma_tile_sizes {
     using T_A_KQ  = tile< 8,  4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -828,7 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int ncols = ncols1 * ncols2;
@@ -840,7 +952,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     using     T_C_VKQ   = typename mma_tile_sizes<ncols>::T_C_VKQ;
 
     constexpr int  cols_per_warp   = T_B_KQ::I;
-    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int  cols_per_thread = get_cols_per_thread();
     constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
@@ -871,6 +983,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
 #if defined(TURING_MMA_AVAILABLE)
     T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE)
+    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #else // Volta
     T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
@@ -1010,6 +1124,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // The partial sums are spread across 8/4 threads.
         constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
         constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;
+#elif defined(AMD_WMMA_AVAILABLE)
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 16;
+        constexpr int offset_last  = 16;
 #else // Volta
         // The partial sums are spread across 2 threads.
         constexpr int offset_first = 2;
@@ -1047,7 +1165,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #if defined(TURING_MMA_AVAILABLE)
         if constexpr (cols_per_warp == 8) {
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
             for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -1068,6 +1186,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 }
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
 #else // Volta
         const int col = (threadIdx.x / 2) % 2;
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1119,6 +1246,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
         const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
         const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
+        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
 #else // Volta
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
         const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1319,7 +1450,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
@@ -1346,7 +1477,7 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1360,6 +1491,13 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
 
+#if defined(AMD_WMMA_AVAILABLE)
+    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_WMMA_AVAILABLE)
+
     static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
 
     constexpr int ncols     = ncols1 * ncols2;
@@ -1473,7 +1611,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1492,7 +1630,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);
     const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
 
-    const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
+    const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
     const int nwarps        = nthreads / WARP_SIZE;
 
     constexpr bool mla = DKQ == 576;
@@ -1512,29 +1650,34 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
+#if defined(GGML_USE_HIP)
+    using fattn_kernel_ptr_t = const void*;
+#else
+    using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     }
 
     launch_fattn<DV, ncols1, ncols2>

+ 39 - 3
ggml/src/ggml-cuda/fattn.cu

@@ -18,12 +18,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
         }
     }
 
-    if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
+    if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
         return;
     }
 
-    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
+    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
@@ -230,7 +230,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // The effective batch size for the kernel can be increased by gqa_ratio.
     // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
-    const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                gqa_opt_applies = false;
+                break;
+            }
+        }
+    }
 
     const int cc = ggml_cuda_info().devices[device].cc;
 
@@ -337,6 +348,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
+    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+        if (can_use_vector_kernel) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+                if (Q->ne[1] == 1) {
+                    if (!gqa_opt_applies) {
+                        return BEST_FATTN_KERNEL_VEC;
+                    }
+                }
+            } else {
+                if (Q->ne[1] <= 2) {
+                    return BEST_FATTN_KERNEL_VEC;
+                }
+            }
+        }
+        int gqa_ratio_eff = 1;
+        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+            gqa_ratio_eff *= 2;
+        }
+        if (Q->ne[1] * gqa_ratio_eff <= 8) {
+            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
+        }
+        return BEST_FATTN_KERNEL_MMA_F16;
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {

+ 47 - 2
ggml/src/ggml-cuda/mma.cuh

@@ -206,10 +206,16 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                // matrix C
 #if defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
+                if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
+                    // matrix C
+                    return 2 * l + (threadIdx.x / 16);
+                } else {
+                    // matrix A&B
+                    return l;
+                }
 #else
+                // matrix C is the transposed matrix A&B on RDNA4
                 return ne * (threadIdx.x / 16) + l;
 #endif // defined(RDNA3)
             } else if constexpr (I == 16 && J == 8) {
@@ -621,6 +627,21 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#elif defined(AMD_WMMA_AVAILABLE)
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+        }
+        return ret;
+    }
+
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        NO_DEVICE_CODE;
+        return tile<8, 8, half2>{};
+    }
 #else // Volta
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -639,6 +660,19 @@ namespace ggml_cuda_mma {
     }
 #endif // defined(TURING_MMA_AVAILABLE)
 
+    static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
+#if defined(RDNA4)
+        const int row = t.get_i(0);
+        const int left_right = t.get_j(0) / 4;
+        const int up_down = row / 8;
+        const int idx = row % 8;
+        reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
+#else
+        GGML_UNUSED_VARS(t);
+        NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+    }
+
     template <int I, int J, typename T, data_layout dl>
     static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
@@ -878,6 +912,17 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(RDNA4)
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;

+ 2 - 0
ggml/src/ggml-cuda/vendors/hip.h

@@ -138,6 +138,8 @@
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
+#define cudaFuncSetAttribute hipFuncSetAttribute
+#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED