Explorar el Código

CUDA: fix unpadded strides in MMA FA kernel (#17891)

Johannes Gäßler hace 1 mes
padre
commit
17f7f4baad
Se han modificado 2 ficheros con 32 adiciones y 21 borrados
  1. 17 20
      ggml/src/ggml-cuda/fattn-mma-f16.cuh
  2. 15 1
      ggml/src/ggml-cuda/fattn.cu

+ 17 - 20
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -955,22 +955,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
     }
 
-    for (; kb0 < kb0_stop-1; ++kb0) {
-        constexpr bool last_iter = false;
-        constexpr bool oob_check = false;
-        constexpr int  k_VKQ_sup = nbatch_fa;
-        flash_attn_ext_f16_iter
-            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
-             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
-            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-    }
     // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
     if constexpr (ncols2 == 1) {
-        if (ne11 % nbatch_fa == 0) {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = false;
+        constexpr bool oob_check = true;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
@@ -978,10 +967,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                  KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-        } else {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = true;
-            const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        }
+        constexpr bool last_iter = true;
+        const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+              T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+    } else {
+        constexpr bool oob_check = false;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
@@ -989,9 +988,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                  KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
-    } else {
         constexpr bool last_iter = true;
-        constexpr bool oob_check = false;
         constexpr int  k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
             <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,

+ 15 - 1
ggml/src/ggml-cuda/fattn.cu

@@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
 
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+    //     are put into the template specialization without GQA optimizations.
+    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                use_gqa_opt = false;
+                break;
+            }
+        }
+    }
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];