Răsfoiți Sursa

ggml-cuda: Fix HIP build (#4528)

regression of #4490
Adds defines for two new datatypes
cublasComputeType_t, cudaDataType_t.

Currently using deprecated hipblasDatatype_t since newer ones very recent.
arlo-phoenix 2 ani în urmă
părinte
comite
a7aee47b98
1 a modificat fișierele cu 2 adăugiri și 0 ștergeri
  1. 2 0
      ggml-cuda.cu

+ 2 - 0
ggml-cuda.cu

@@ -31,6 +31,7 @@
 #define CUDA_R_16F  HIPBLAS_R_16F
 #define CUDA_R_16F  HIPBLAS_R_16F
 #define CUDA_R_32F  HIPBLAS_R_32F
 #define CUDA_R_32F  HIPBLAS_R_32F
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
 #define cublasCreate hipblasCreate
 #define cublasGemmEx hipblasGemmEx
 #define cublasGemmEx hipblasGemmEx
 #define cublasGemmBatchedEx hipblasGemmBatchedEx
 #define cublasGemmBatchedEx hipblasGemmBatchedEx
@@ -40,6 +41,7 @@
 #define cublasSetStream hipblasSetStream
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
 #define cublasStatus_t hipblasStatus_t
+#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess