Просмотр исходного кода

HIP: disable rocwmma on gfx12 by default until rocm 7.0 (#14202)

uvos 7 месяцев назад
Родитель
Коммит
7d6d91babf
3 измененных файла: 7 добавлений и 2 удаления
  1. 1 0
      ggml/CMakeLists.txt
  2. 2 2
      ggml/src/ggml-cuda/common.cuh
  3. 4 0
      ggml/src/ggml-hip/CMakeLists.txt

+ 1 - 0
ggml/CMakeLists.txt

@@ -172,6 +172,7 @@ option(GGML_HIP                             "ggml: use HIP"
 option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental, slow"         OFF)
 option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                 ON)
 option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"         OFF)
+option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12   "ggml: enable rocWMMA FlashAttention on GFX12"    OFF)
 option(GGML_VULKAN                          "ggml: use Vulkan"                                OFF)
 option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
 option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)

+ 2 - 2
ggml/src/ggml-cuda/common.cuh

@@ -207,9 +207,9 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 #define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE

+ 4 - 0
ggml/src/ggml-hip/CMakeLists.txt

@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
     add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
 endif()
 
+if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR "${hip_VERSION}" VERSION_GREATER_EQUAL 7.0) # quoted so an empty hip_VERSION compares as false instead of breaking the if()
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) # tested together with RDNA4 in ggml-cuda/common.cuh to define FP16_MMA_AVAILABLE
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()