|
|
@@ -4,7 +4,7 @@
|
|
|
#include <hip/hip_runtime.h>
|
|
|
#include <hipblas/hipblas.h>
|
|
|
#include <hip/hip_fp16.h>
|
|
|
-#include <hip/hip_bfloat16.h>
|
|
|
+#include <hip/hip_bf16.h>
|
|
|
|
|
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
|
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
|
|
@@ -135,7 +135,7 @@
|
|
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
|
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
|
|
|
|
|
-#if HIP_VERSION >= 70000000
|
|
|
+#if HIP_VERSION >= 60500000
|
|
|
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
|
|
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
|
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
|
|
@@ -147,7 +147,7 @@
|
|
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
|
|
#define cublasComputeType_t hipblasDatatype_t
|
|
|
#define cudaDataType_t hipblasDatatype_t
|
|
|
-#endif // HIP_VERSION >= 7000000
|
|
|
+#endif // HIP_VERSION >= 60500000
|
|
|
|
|
|
#if !defined(__HIP_PLATFORM_AMD__)
|
|
|
#error "The HIP backend supports only AMD targets"
|
|
|
@@ -179,8 +179,7 @@
|
|
|
#define RDNA4
|
|
|
#endif
|
|
|
|
|
|
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
|
|
- defined(__gfx1150__) || defined(__gfx1151__)
|
|
|
+#if defined(__GFX11__)
|
|
|
#define RDNA3
|
|
|
#endif
|
|
|
|
|
|
@@ -197,8 +196,8 @@
|
|
|
#define __has_builtin(x) 0
|
|
|
#endif
|
|
|
|
|
|
-typedef hip_bfloat16 nv_bfloat16;
|
|
|
-typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
|
|
|
+typedef __hip_bfloat16 nv_bfloat16;
|
|
|
+typedef __hip_bfloat162 nv_bfloat162;
|
|
|
|
|
|
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
|
|
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|