|
|
@@ -6,19 +6,19 @@
|
|
|
#include "fattn-common.cuh"
|
|
|
#include "fattn-wmma-f16.cuh"
|
|
|
|
|
|
-#ifdef FP16_MMA_AVAILABLE
|
|
|
+#ifdef GGML_USE_WMMA_FATTN
|
|
|
#if !defined(GGML_USE_HIP)
|
|
|
#include <mma.h>
|
|
|
-#ifdef GGML_USE_MUSA
|
|
|
+#if defined(GGML_USE_MUSA)
|
|
|
namespace wmma = mtmusa::wmma;
|
|
|
#else // GGML_USE_MUSA
|
|
|
namespace wmma = nvcuda::wmma;
|
|
|
#endif // GGML_USE_MUSA
|
|
|
-#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
|
|
|
+#elif defined(GGML_USE_HIP)
|
|
|
#include <rocwmma/rocwmma.hpp>
|
|
|
namespace wmma = rocwmma;
|
|
|
#endif // !defined(GGML_USE_HIP)
|
|
|
-#endif // FP16_MMA_AVAILABLE
|
|
|
+#endif // GGML_USE_WMMA_FATTN
|
|
|
|
|
|
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
|
|
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
|
|
|
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
|
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
|
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
|
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
|
|
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
|
|
|
// Skip unused kernel variants for faster compilation:
|
|
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
|
|
NO_DEVICE_CODE;
|
|
|
@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
ne31, ne32, ne33,
|
|
|
nb31, nb32, nb33);
|
|
|
NO_DEVICE_CODE;
|
|
|
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
|
|
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
|
|
|
}
|
|
|
|
|
|
constexpr int get_max_power_of_2(int x) {
|