
CUDA: add option to compile without FlashAttention (#12025)

Johannes Gäßler, 10 months ago
parent
commit
a28e0d5eb1

+ 12 - 0
Makefile

@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
 	MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
 endif # GGML_CUDA_CCBIN
 
+ifdef GGML_CUDA_NO_FA
+	MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 ifdef GGML_CUDA_FA_ALL_QUANTS
 	MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
 endif # GGML_CUDA_FA_ALL_QUANTS
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
 	HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
 endif # GGML_CUDA_NO_PEER_COPY
 
+ifdef GGML_CUDA_NO_FA
+	HIPFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 	OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
 	OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
 	OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
 	MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
 endif # GGML_CUDA_NO_PEER_COPY
 
+ifdef GGML_CUDA_NO_FA
+	MUSAFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 ifdef GGML_CUDA_FA_ALL_QUANTS
 	MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
 endif # GGML_CUDA_FA_ALL_QUANTS
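
Usage note (not part of the patch): with these Makefile additions, setting GGML_CUDA_NO_FA on the make command line, e.g. something like make GGML_CUDA=1 GGML_CUDA_NO_FA=1 (assuming the usual GGML_CUDA=1 CUDA build), adds -DGGML_CUDA_NO_FA to the NVCC, HIP and MUSA flags alike, which is what disables the FlashAttention code paths changed below.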

+ 1 - 0
ggml/CMakeLists.txt

@@ -151,6 +151,7 @@ set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF)
 option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA                         "ggml: compile ggml FlashAttention CUDA kernels"  ON)
 option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"     OFF)
 option(GGML_CUDA_GRAPHS                     "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})
 

+ 4 - 0
ggml/src/ggml-cuda/CMakeLists.txt

@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()

+ 2 - 2
ggml/src/ggml-cuda/common.cuh

@@ -204,9 +204,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
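
The FLASH_ATTN_AVAILABLE macro defined here is what all of the kernel changes below key off. As a minimal illustrative sketch (not taken from the patch; the kernel name and parameters are hypothetical), the pattern applied to each fattn-* kernel is to wrap the entire body in the availability check and reduce it to NO_DEVICE_CODE otherwise, rather than an early return inside an always-compiled body:

// Illustrative only: the guard pattern the fattn-* kernels follow after this change.
// Assumes common.cuh is included so FLASH_ATTN_AVAILABLE and NO_DEVICE_CODE are defined.
static __global__ void example_flash_attn_kernel(/* Q, K, V, dst, ... */) {
#ifdef FLASH_ATTN_AVAILABLE
    // ... the actual FlashAttention computation would go here ...
#else
    NO_DEVICE_CODE; // stub: this kernel variant was compiled without FlashAttention support
#endif // FLASH_ATTN_AVAILABLE
}

Kernels that additionally require FP16 or a specific architecture combine the checks, e.g. #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE), as the hunks below show.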

+ 4 - 4
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef NEW_MMA_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // NEW_MMA_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
 
 template <int D, int ncols1, int ncols2>

+ 2 - 7
ggml/src/ggml-cuda/fattn-tile-f16.cu

@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>

+ 4 - 4
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>

+ 2 - 7
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>

+ 4 - 4
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>

+ 2 - 2
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 }
 
 constexpr int get_max_power_of_2(int x) {

+ 1 - 1
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
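
The net effect at runtime: when the library is built with GGML_CUDA_NO_FA, FLASH_ATTN_AVAILABLE is never defined, so the CUDA backend reports GGML_OP_FLASH_ATTN_EXT as unsupported here and the operation can be handled by another backend (typically the CPU) instead of hitting an unusable device stub.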

+ 4 - 0
ggml/src/ggml-hip/CMakeLists.txt

@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)

+ 4 - 0
ggml/src/ggml-musa/CMakeLists.txt

@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()