Browse Source

CUDA: add gqa_ratio 4 for GLM 4.7 flash (#18953)

Aman Gupta 1 week ago
parent
commit
b70d251076

+ 23 - 8
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     constexpr int  ncols           = ncols1 * ncols2;
     constexpr int  cols_per_warp   = T_B_KQ::I;
     constexpr int  cols_per_thread = get_cols_per_thread();
-    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         } else {
-            static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
 #pragma unroll
             for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
                 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-                    // Wide version of KQ_C is column-major
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                    } else {
+                        // Wide version of KQ_C is column-major
 #if defined(AMD_WMMA_AVAILABLE)
-                    // RDNA matrix C is column-major.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                        // RDNA matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
 #else
-                    // swap A and B for CUDA.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+                        // swap A and B for CUDA.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
 #endif // defined(AMD_WMMA_AVAILABLE)
+                    }
                 }
             }
         }
@@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     constexpr int  cols_per_warp   = T_B_KQ::I;
     constexpr int  cols_per_thread = get_cols_per_thread();
-    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);
@@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    if (ncols1*ncols2 < 32) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // VOLTA_MMA_AVAILABLE
+
 #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
     if (ncols1*ncols2 > 32) {
         NO_DEVICE_CODE;
@@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);

+ 12 - 0
ggml/src/ggml-cuda/fattn-tile.cuh

@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
 
     return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
 
     return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
 
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
 
@@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
             launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
             return;
         }
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+            return;
+        }
     }
 
     if constexpr (DV <= 256) {

+ 7 - 3
ggml/src/ggml-cuda/fattn.cu

@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
 
             GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
             const int gqa_ratio = Q->ne[2] / K->ne[2];
-            GGML_ASSERT(gqa_ratio % 16 == 0);
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+            GGML_ASSERT(gqa_ratio % 4 == 0);
+            if (gqa_ratio % 16 == 0) {
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
+            }
         } break;
         default:
             GGML_ABORT("fatal error");
@@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != 512) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+            if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;

+ 1 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

+ 1 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);

+ 1 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);

+ 1 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

+ 1 - 1
ggml/src/ggml-cuda/template-instances/generate_cu_files.py

@@ -85,7 +85,7 @@ for ncols in [8, 16, 32, 64]:
                     continue
                 if head_size_kq != 576 and ncols2 == 16:
                     continue
-                if head_size_kq == 576 and ncols2 != 16:
+                if head_size_kq == 576 and ncols2 not in (4, 16):
                     continue
                 head_size_v = head_size_kq if head_size_kq != 576 else 512
                 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))