
CUDA: generalized (mma) FA, add Volta support (#17505)

* CUDA: generalized (mma) FA, add Volta support

* use struct for MMA FA kernel config

---------

Co-authored-by: Aman Gupta <aman>
Johannes Gäßler 1 month ago
commit 2e1c9cd814

+ 1 - 1
ggml/include/ggml.h

@@ -2279,7 +2279,7 @@ extern "C" {
             float                 stop,
             float                 step);
 
-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 1
 
     // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
     // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
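The padding requirement for the KQ mask is relaxed from 64 rows to 1 (the matching assertion in launch_fattn is removed further down). As a rough host-side sketch, not part of this commit and assuming a ggml context ctx and a KV length n_kv that are purely illustrative, the mask's second dimension no longer needs to be rounded up:

// Hypothetical sketch: sizing the KQ mask for ggml_flash_attn_ext.
// With GGML_KQ_MASK_PAD == 1, GGML_PAD(n_batch, GGML_KQ_MASK_PAD) == n_batch,
// so the mask needs one row per query token instead of a multiple of 64.
const int64_t n_batch     = 7;
const int64_t n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD); // 64 before this commit, 7 after
ggml_tensor * kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, n_batch_pad, 1, 1);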

+ 12 - 10
ggml/src/ggml-cuda/fattn-common.cuh

@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
+        const int nbatch_fa) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int iter_k = ne11 / FATTN_KQ_STRIDE;
-    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
 
     const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
     const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
 template <int DV, int ncols1, int ncols2>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
-    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
+    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -790,8 +791,6 @@ void launch_fattn(
     GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
@@ -915,7 +914,7 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
 
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -970,6 +969,9 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // TODO other tensor dimensions after removal of WMMA kernel:
+    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
     GGML_ASSERT(block_dim.x % warp_size == 0);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
@@ -980,7 +982,7 @@ void launch_fattn(
         KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        Q->ne[0], ne01,     Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
         mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
@@ -995,7 +997,7 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
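launch_fattn now passes ne01 as a uint3 produced by init_fastdiv_values, so the kernels can divide and take remainders by the number of query rows without a hardware integer division. A minimal self-contained sketch of the technique (my own reconstruction, not the ggml-cuda helpers verbatim): precompute a magic multiplier and shift on the host, keep the original divisor in .z so kernels can still read it back as int(ne01.z), and reduce division to a high multiply plus a shift on the device.

#include <cuda_runtime.h>
#include <cstdint>

// Host: pack divisor d into (magic multiplier, shift, original value).
static uint3 init_fastdiv_values_sketch(uint32_t d) {
    uint32_t l = 0;
    while (l < 32 && (uint32_t(1) << l) < d) {
        ++l;
    }
    const uint32_t mp = uint32_t(((uint64_t(1) << 32) * ((uint64_t(1) << l) - d)) / d + 1);
    return make_uint3(mp, l, d); // .z keeps d so kernels can recover the plain dimension
}

// Device: n / d and n % d without a hardware divide.
static __device__ __forceinline__ uint32_t fastdiv_sketch(uint32_t n, const uint3 fd) {
    return (__umulhi(n, fd.x) + n) >> fd.y;
}
static __device__ __forceinline__ uint32_t fastmodulo_sketch(uint32_t n, const uint3 fd) {
    return n - fastdiv_sketch(n, fd) * fd.z;
}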

File diff suppressed because it is too large
+ 383 - 400
ggml/src/ggml-cuda/fattn-mma-f16.cuh


+ 18 - 18
ggml/src/ggml-cuda/fattn-tile.cuh

@@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         const half2 * const __restrict__ K_h2,
         const half2 * const __restrict__ V_h2,
         const half  * const __restrict__ mask,
+        const uint3 ne01,
         const float logit_softcap,
         const float slope,
         T_KQ      * const KQ,
@@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         float * const KQ_sum,
         T_acc * const VKQ,
         const int k_VKQ_0,
-        const int k_VKQ_max) {
+        const int k_VKQ_max,
+        const int col_Q_0) {
     constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
     constexpr int cpy_ne = cpy_nb / 4;
 
@@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
     // Apply logit softcap + mask, update KQ_max:
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; ++jc0) {
-        const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
+        const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);
 
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
@@ -736,7 +738,7 @@ static __global__ void flash_attn_tile(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -781,11 +783,11 @@ static __global__ void flash_attn_tile(
     const int sequence = blockIdx.z / (ne02/ncols2);
     const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0              + nb01*col_Q_0);
+    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);
     const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
 
-    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
+    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
 
     const int stride_K2   = nb11 / sizeof(half2);
     const int stride_V2   = nb21 / sizeof(half2);
@@ -842,11 +844,9 @@ static __global__ void flash_attn_tile(
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
                 float tmp_f[cpy_ne_D] = {0.0f};
-                if (ncols1 == 1 || col_Q_0 + j < ne01) {
-                    ggml_cuda_memcpy_1<sizeof(tmp_f)>
-                        (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
-                                     + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
-                }
+                ggml_cuda_memcpy_1<sizeof(tmp_f)>
+                    (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
+                                 + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
 
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -881,23 +881,23 @@ static __global__ void flash_attn_tile(
         while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
             constexpr bool oob_check = false;
             flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
-                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
-                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
             k_VKQ_0 += gridDim.y*nbatch_fa;
         }
         if (k_VKQ_0 < k_VKQ_max) {
             constexpr bool oob_check = true;
             flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
-                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
-                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
         }
     } else {
         // Branch without out-of-bounds checks.
         for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
             constexpr bool oob_check = false;
             flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
-                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
-                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
         }
     }
 
@@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile(
         const int j = jc / ncols2;
         const int c = jc % ncols2;
 
-        if (ncols1 > 1 && col_Q_0 + j >= ne01) {
+        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
             return;
         }
 
         const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
 
-        const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
+        const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
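The tile kernel no longer branches per Q row inside the main loop: rows past the end are wrapped with fastmodulo so the loads and mask reads stay in bounds, and they are only filtered out at the final store via col_Q_0 + j >= int(ne01.z). A tiny host-side illustration of that bookkeeping (plain ints stand in for the fastdiv struct, values chosen arbitrarily):

#include <cstdio>

int main() {
    const int ne01    = 6; // real number of query rows (ne01.z on the device)
    const int ncols1  = 4; // query rows handled per block
    const int col_Q_0 = 4; // first row of the second block

    for (int j = 0; j < ncols1; ++j) {
        const int  row_load = (col_Q_0 + j) % ne01; // what fastmodulo computes in the kernel
        const bool stored   = col_Q_0 + j < ne01;   // the check kept at the write-out
        printf("j=%d: loads row %d, result %s\n", j, row_load, stored ? "kept" : "dropped");
    }
    return 0;
}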

+ 9 - 9
ggml/src/ggml-cuda/fattn-vec.cuh

@@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec(
             float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
 
             // Set memory to zero if out of bounds:
-            if (ncols > 1 && ic0 + j >= ne01) {
+            if (ncols > 1 && ic0 + j >= int(ne01.z)) {
 #pragma unroll
                 for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                     const int i = i0 + threadIdx.x;
@@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec(
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
 
                 float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
-                if (ncols == 1 || ic0 + j < ne01) {
+                if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);
                     ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
                 }
@@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec(
 #pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
-                if (ncols == 1 || ic0 + j < ne01) {
+                if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);
                     ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
                 }
@@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec(
                     sum = logit_softcap*tanhf(sum);
                 }
 
-                if (mask) {
+                if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
                     sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
                 }
 
@@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ncols > 1 && ic0 + j_VKQ >= ne01) {
+        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
             break;
         }
 
@@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec(
                 if (gridDim.y == 1) {
                     dst_val /= KQ_sum[j_VKQ];
                 }
-                dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
+                dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
             }
         }
 
@@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec(
 
     }
 
-    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
-        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
+        dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
     }
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,

+ 9 - 8
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16(
             if (i0 + warp_size > D && i >= D) {
                 break;
             }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
         }
     }
 
@@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
+                        __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                 }
                 KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
@@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                 }
                 KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
+        if (ic0 + j_VKQ >= int(ne01.z)) {
             return;
         }
 
@@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
-        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+        const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
 
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
@@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
 }
 
 constexpr int get_max_power_of_2(int x) {

+ 2 - 2
ggml/src/ggml-cuda/fattn-wmma-f16.cuh

@@ -2,9 +2,9 @@
 
 #include "common.cuh"
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#if defined(GGML_USE_MUSA)
 #define GGML_USE_WMMA_FATTN
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#endif // defined(GGML_USE_MUSA)
 
 #if defined(GGML_HIP_ROCWMMA_FATTN)
 #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)

+ 19 - 5
ggml/src/ggml-cuda/fattn.cu

@@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
     const ggml_tensor * Q = dst->src[0];
 
     if constexpr (ncols2 <= 8) {
-        if (Q->ne[1] <= 8/ncols2) {
+        if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
             ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
             return;
         }
     }
 
-    if (Q->ne[1] <= 16/ncols2) {
+    if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
         return;
     }
@@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const bool use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
-    // If Turing tensor cores available, use them:
-    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+    // If Turing tensor cores are available, use them:
+    if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
         if (can_use_vector_kernel) {
             if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
                 return BEST_FATTN_KERNEL_VEC;
             }
         }
+        return BEST_FATTN_KERNEL_MMA_F16;
+    }
 
+    if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
+        int gqa_ratio_eff = 1;
+        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+            gqa_ratio_eff *= 2;
+        }
+        if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
+        if (Q->ne[1] * gqa_ratio_eff <= 16) {
+            return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
+        }
         return BEST_FATTN_KERNEL_MMA_F16;
     }
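The Volta branch picks between the vector, tile and MMA kernels based on an effective GQA ratio: the largest power of two that divides gqa_ratio, capped at ncols2_max. For instance, gqa_ratio = 6 stops at 2 (6 is not divisible by 4), while gqa_ratio = 8 with an 8-wide cap yields 8. A standalone restatement of that loop, with the names taken from the diff but wrapped in a helper purely for illustration:

// Illustration only: the effective GQA ratio as computed above, i.e. the largest
// power of two dividing gqa_ratio, limited to ncols2_max.
static int gqa_ratio_eff_sketch(int gqa_ratio, int ncols2_max) {
    int gqa_ratio_eff = 1;
    while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
        gqa_ratio_eff *= 2;
    }
    return gqa_ratio_eff;
}
// gqa_ratio_eff_sketch(6, 8) == 2, gqa_ratio_eff_sketch(8, 8) == 8, gqa_ratio_eff_sketch(1, 8) == 1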
 

+ 194 - 73
ggml/src/ggml-cuda/mma.cuh

@@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
 
 namespace ggml_cuda_mma {
 
+    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
+    //     effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
+    // In those cases the data can be split in different ways across the warp.
+    enum data_layout {
+        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+        // For the A/C matrices this means I major == row major, J major == column major.
+        // For the B matrix this means I major == column major, J major == row major.
+        // MIRRORED == Each data value is held exactly once per thread subgroup.
+        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
+        DATA_LAYOUT_I_MAJOR_MIRRORED  = 10,
+        DATA_LAYOUT_J_MAJOR_MIRRORED  = 20,
+    };
+    // Implemented mma combinations are:
+    //   - (I_MAJOR, I_MAJOR)          -> I_MAJOR
+    //   - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+    //   - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
+    struct tile {};
+
     template <int I_, int J_, typename T>
-    struct tile {
-        static constexpr int I  = I_;
-        static constexpr int J  = J_;
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
@@ -131,9 +152,9 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 32 && J == 8) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
-                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
 #else
-                return (l & 2) | (threadIdx.x & ~2);
+                return (l & 2) + (threadIdx.x & ~2);
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
             } else {
                 NO_DEVICE_CODE;
@@ -143,7 +164,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 32 && J == 8) {
-                return (threadIdx.x & 2) | (l & (4 + 1));
+                return (threadIdx.x & 2) + (l & (4 + 1));
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -196,9 +217,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 8) | (threadIdx.x / 4);
+                return ((l / 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 32 && J == 8) {
                 return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
@@ -211,11 +232,11 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((threadIdx.x % 4) * 2) | (l % 2);
+                return ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 32 && J == 8) {
                 return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
@@ -227,26 +248,24 @@ namespace ggml_cuda_mma {
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, half2> {
-        static constexpr int I  = I_;
-        static constexpr int J  = J_;
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I ==  8 && J ==  8) return true;
-            if (I == 32 && J ==  8) return true;
+            if (I == 32 && J ==  4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && J == 8) {
-                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
-            } else if constexpr (I == 32 && J == 8) {
+            if constexpr (I == 32 && J == 4) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
-                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
 #else
                 return threadIdx.x;
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
@@ -257,7 +276,7 @@ namespace ggml_cuda_mma {
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr ((I == 8 || I == 32) && J == 8) {
+            if constexpr (I == 32 && J == 4) {
                 return l;
             } else {
                 NO_DEVICE_CODE;
@@ -307,11 +326,11 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return (l * 8) | (threadIdx.x / 4);
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l % 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 32 && J == 8) {
-                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -320,13 +339,13 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 4) | (threadIdx.x % 4);
+                return ((l / 2) * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 32 && J == 8) {
-                return ((l & 2) * 2) | (threadIdx.x % 4);
+                return ((l & 2) * 2) + (threadIdx.x % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -336,14 +355,15 @@ namespace ggml_cuda_mma {
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, nv_bfloat162> {
-        static constexpr int I  = I_;
-        static constexpr int J  = J_;
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+        static constexpr int         ne = I * J / WARP_SIZE;
 
-#if defined(AMD_WMMA_AVAILABLE)
-        static constexpr int ne = I * J / 32;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
+#if defined(AMD_WMMA_AVAILABLE)
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 8) return true;
             return false;
@@ -367,9 +387,6 @@ namespace ggml_cuda_mma {
             }
         }
 #else
-        static constexpr int ne = I * J / WARP_SIZE;
-        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
-
         static constexpr __device__ bool supported() {
             if (I ==  8 && J ==  8) return true;
             if (I == 16 && J ==  4) return true;
@@ -381,9 +398,9 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return (l * 8) | (threadIdx.x / 4);
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -392,11 +409,11 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 4) | (threadIdx.x % 4);
+                return ((l / 2) * 4) + (threadIdx.x % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -405,6 +422,73 @@ namespace ggml_cuda_mma {
 #endif  // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int         ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int /*l*/) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+        static constexpr int         ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((l / 2) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 2) + (l % 2);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+#if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -422,9 +506,26 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#else // Volta
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
+            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
+
+            // On Volta FP16 and FP32 tiles have a different memory layout,
+            //     for the conversion threads with an offset of 2 need to exchange half their values:
+            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
+                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
+        }
+        return ret;
+    }
+#endif // defined(TURING_MMA_AVAILABLE)
 
-    template <int I, int J, typename T>
-    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+    template <int I, int J, typename T, data_layout dl>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
         if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
 #pragma unroll
@@ -511,18 +612,6 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-        GGML_UNUSED_VARS(t, xs0, stride);
-        NO_DEVICE_CODE;
-#else
-        load_generic(t, xs0, stride);
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-#endif // TURING_MMA_AVAILABLE
-    }
-
-    template <typename T>
-    static __device__ __forceinline__ void load_ldmatrix(
-            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #if 1
         // TODO: more generic handling
@@ -533,9 +622,31 @@ namespace ggml_cuda_mma {
         load_generic(t, xs0, stride);
 #endif // 1
 #else
-        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
-        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
-        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l0 = 0; l0 < t.ne; l0 += 2) {
+            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
+        }
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+#else
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     }
 
@@ -860,14 +971,14 @@ namespace ggml_cuda_mma {
     template <typename T1, typename T2, int J, int K>
     static __device__ __forceinline__ void mma(
             tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
-        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
-        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        tile      <16, J, T1> * D16 = reinterpret_cast<      tile<16, J, T1> *>(&D);
+        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
         mma(D16[0], A16[0], B);
         mma(D16[1], A16[1], B);
     }
 
     static __device__ __forceinline__ void mma(
-            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -880,20 +991,30 @@ namespace ggml_cuda_mma {
             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
-        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
-            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
-            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
-            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
-        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
-            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
-            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
-            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
 #else
-        tile      <16, 8, float> * D16 = reinterpret_cast<tile      <16, 8, float> *>(&D);
-        const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
-        mma(D16[0], A16[0], B);
-        mma(D16[1], A16[1], B);
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
 
 static __device__ __forceinline__ void mma(
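To see what the new mirrored layout means concretely, the get_i/get_j formulas of tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> can be replayed on the host: each 4-lane subgroup holds one complete 4x4 half2 sub-block, so every value appears exactly once per subgroup, matching the data_layout comment at the top of the namespace. A small standalone program, with the index formulas copied from the struct and everything else illustrative:

#include <cstdio>

int main() {
    const int ne = 8 * 4 / (32 / 4); // 4 half2 elements per lane, as in the struct above
    for (int lane = 0; lane < 32; ++lane) {
        printf("lane %2d holds:", lane);
        for (int l = 0; l < ne; ++l) {
            const int i = ((lane / 16) * 4) + (lane % 4); // get_i(l)
            const int j = l;                              // get_j(l)
            printf(" (i=%d, j=%d)", i, j);
        }
        printf("\n");
    }
    return 0;
}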

+ 29 - 30
ggml/src/ggml-cuda/mmf.cuh

@@ -37,23 +37,19 @@ static __global__ void mul_mat_f(
     typedef tile<16,       8, T>     tile_A;
     typedef tile<tile_B_I, 8, T>     tile_B;
     typedef tile<16,       tile_C_J, float> tile_C;
-
-    constexpr bool a_supported = tile_A::supported();
-    constexpr bool b_supported = tile_B::supported();
-    constexpr bool c_supported = tile_C::supported();
-    constexpr bool supported = a_supported && b_supported && c_supported;
 #else
-    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
-    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-    constexpr bool supported = I_16_supported || I_32_supported;
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
-
-    typedef tile<I_preferred, 8, T>     tile_A;
-    typedef tile<8,           8, T>     tile_B;
-    typedef tile<I_preferred, 8, float> tile_C;
+#ifdef VOLTA_MMA_AVAILABLE
+    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
+    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
+#else
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile<8,  8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
 #endif // defined(AMD_WMMA_AVAILABLE)
-    if constexpr (!supported) {
+    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
         NO_DEVICE_CODE;
         return;
     }
@@ -248,6 +244,9 @@ static __global__ void mul_mat_f(
             }
         }
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    }
+#endif //VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids(
     typedef tile<16,       8, T>     tile_A;
     typedef tile<tile_B_I, 8, T>     tile_B;
     typedef tile<16,       tile_C_J, float> tile_C;
-
-    constexpr bool a_supported = tile_A::supported();
-    constexpr bool b_supported = tile_B::supported();
-    constexpr bool c_supported = tile_C::supported();
-    constexpr bool supported = a_supported && b_supported && c_supported;
 #else
-    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
-    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-    constexpr bool supported = I_16_supported || I_32_supported;
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
-
-    typedef tile<I_preferred, 8, T>     tile_A;
-    typedef tile<8,           8, T>     tile_B;
-    typedef tile<I_preferred, 8, float> tile_C;
+#ifdef VOLTA_MMA_AVAILABLE
+    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
+    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
+#else
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile<8,  8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
 #endif // defined(AMD_WMMA_AVAILABLE)
-    if constexpr (!supported) {
+    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
         NO_DEVICE_CODE;
         return;
     }
 
+
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
     constexpr int ntA = rows_per_block / tile_A::I;
@@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids(
             }
         }
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    }
+#endif // VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

Some files were not shown because too many files have changed in this diff