|
|
@@ -12,6 +12,10 @@
|
|
|
#else
|
|
|
#define GGML_COMMON_DECL_CUDA
|
|
|
#define GGML_COMMON_IMPL_CUDA
|
|
|
+#if defined(GGML_USE_MUSA)
|
|
|
+#define GGML_COMMON_DECL_MUSA
|
|
|
+#define GGML_COMMON_IMPL_MUSA
|
|
|
+#endif
|
|
|
#endif
|
|
|
#include "ggml-common.h"
|
|
|
|
|
|
@@ -114,6 +118,150 @@
|
|
|
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
|
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
|
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
|
|
+#elif defined(GGML_USE_MUSA)
|
|
|
+#include <musa_runtime.h>
|
|
|
+#include <musa.h>
|
|
|
+#include <mublas.h>
|
|
|
+#include <musa_fp16.h>
|
|
|
+// XXX: Keep the following order the same as hipBLAS
|
|
|
+// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
|
|
|
+// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
|
|
|
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
|
|
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
|
|
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
|
|
+#define CUBLAS_OP_N MUBLAS_OP_N
|
|
|
+#define CUBLAS_OP_T MUBLAS_OP_T
|
|
|
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
|
|
+// #define CUBLAS_TF32_TENSOR_OP_MATH 0
|
|
|
+#define CUDA_R_16F MUSA_R_16F
|
|
|
+#define CUDA_R_32F MUSA_R_32F
|
|
|
+// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
|
|
+// #define cublasComputeType_t mublasComputeType_t
|
|
|
+#define cublasCreate mublasCreate
|
|
|
+#define cublasDestroy mublasDestroy
|
|
|
+#define cublasGemmEx mublasGemmEx
|
|
|
+#define cublasGemmBatchedEx mublasGemmBatchedEx
|
|
|
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
|
|
+#define cublasHandle_t mublasHandle_t
|
|
|
+// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
|
|
+#define cublasSetMathMode mublasSetMathMode
|
|
|
+#define cublasSetStream mublasSetStream
|
|
|
+#define cublasSgemm mublasSgemm
|
|
|
+#define cublasStatus_t mublasStatus_t
|
|
|
+#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
|
|
|
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
|
|
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
|
|
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
|
|
+#define cudaDeviceProp musaDeviceProp
|
|
|
+#define cudaDeviceSynchronize musaDeviceSynchronize
|
|
|
+#define cudaError_t musaError_t
|
|
|
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
|
|
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
|
|
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
|
|
+#define cudaEventDisableTiming musaEventDisableTiming
|
|
|
+#define cudaEventRecord musaEventRecord
|
|
|
+#define cudaEventSynchronize musaEventSynchronize
|
|
|
+#define cudaEvent_t musaEvent_t
|
|
|
+#define cudaEventDestroy musaEventDestroy
|
|
|
+#define cudaFree musaFree
|
|
|
+#define cudaFreeHost musaFreeHost
|
|
|
+#define cudaGetDevice musaGetDevice
|
|
|
+#define cudaGetDeviceCount musaGetDeviceCount
|
|
|
+#define cudaGetDeviceProperties musaGetDeviceProperties
|
|
|
+#define cudaGetErrorString musaGetErrorString
|
|
|
+#define cudaGetLastError musaGetLastError
|
|
|
+#define cudaHostRegister musaHostRegister
|
|
|
+#define cudaHostRegisterPortable musaHostRegisterPortable
|
|
|
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
|
|
+#define cudaHostUnregister musaHostUnregister
|
|
|
+#define cudaLaunchHostFunc musaLaunchHostFunc
|
|
|
+#define cudaMalloc musaMalloc
|
|
|
+#define cudaMallocHost musaMallocHost
|
|
|
+#define cudaMemcpy musaMemcpy
|
|
|
+#define cudaMemcpyAsync musaMemcpyAsync
|
|
|
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
|
|
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
|
|
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
|
|
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
|
|
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
|
|
+#define cudaMemcpyKind musaMemcpyKind
|
|
|
+#define cudaMemset musaMemset
|
|
|
+#define cudaMemsetAsync musaMemsetAsync
|
|
|
+#define cudaMemGetInfo musaMemGetInfo
|
|
|
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
|
|
+#define cudaSetDevice musaSetDevice
|
|
|
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
|
|
+#define cudaStreamDestroy musaStreamDestroy
|
|
|
+#define cudaStreamFireAndForget musaStreamFireAndForget
|
|
|
+#define cudaStreamNonBlocking musaStreamNonBlocking
|
|
|
+#define cudaStreamPerThread musaStreamPerThread
|
|
|
+#define cudaStreamSynchronize musaStreamSynchronize
|
|
|
+#define cudaStreamWaitEvent musaStreamWaitEvent
|
|
|
+#define cudaStream_t musaStream_t
|
|
|
+#define cudaSuccess musaSuccess
|
|
|
+
|
|
|
+// XXX: Other CUDA => MUSA mapping
|
|
|
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
|
|
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
|
|
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
|
|
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
|
|
+#define CUdevice MUdevice
|
|
|
+#define CUdeviceptr MUdeviceptr
|
|
|
+#define CUmemAccessDesc MUmemAccessDesc
|
|
|
+#define CUmemAllocationProp MUmemAllocationProp
|
|
|
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
|
|
+#define cuDeviceGet muDeviceGet
|
|
|
+#define cuDeviceGetAttribute muDeviceGetAttribute
|
|
|
+#define cuMemAddressFree muMemAddressFree
|
|
|
+#define cuMemAddressReserve muMemAddressReserve
|
|
|
+#define cuMemCreate muMemCreate
|
|
|
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
|
|
+#define cuMemMap muMemMap
|
|
|
+#define cuMemRelease muMemRelease
|
|
|
+#define cuMemSetAccess muMemSetAccess
|
|
|
+#define cuMemUnmap muMemUnmap
|
|
|
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
|
|
+#define cudaFuncSetAttribute musaFuncSetAttribute
|
|
|
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
|
|
+#define make_cudaExtent make_musaExtent
|
|
|
+#define make_cudaPitchedPtr make_musaPitchedPtr
|
|
|
+
|
|
|
+// XXX: USE_CUDA_GRAPH
|
|
|
+#define CUDA_SUCCESS MUSA_SUCCESS
|
|
|
+#define CUresult MUresult
|
|
|
+#define cuGetErrorString muGetErrorString
|
|
|
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
|
|
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
|
|
+#define cudaGraphDestroy musaGraphDestroy
|
|
|
+#define cudaGraphExecDestroy musaGraphExecDestroy
|
|
|
+#define cudaGraphExec_t musaGraphExec_t
|
|
|
+#define cudaGraphExecUpdate musaGraphExecUpdate
|
|
|
+#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
|
|
+#define cudaGraphGetNodes musaGraphGetNodes
|
|
|
+#define cudaGraphInstantiate musaGraphInstantiate
|
|
|
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
|
|
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
|
|
+#define cudaGraphLaunch musaGraphLaunch
|
|
|
+#define cudaGraphNodeGetType musaGraphNodeGetType
|
|
|
+#define cudaGraphNode_t musaGraphNode_t
|
|
|
+#define cudaGraphNodeType musaGraphNodeType
|
|
|
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
|
|
+#define cudaGraph_t musaGraph_t
|
|
|
+#define cudaKernelNodeParams musaKernelNodeParams
|
|
|
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
|
|
+#define cudaStreamEndCapture musaStreamEndCapture
|
|
|
+
|
|
|
+// XXX: cuBLAS => muBLAS mapping
|
|
|
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
|
|
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
|
|
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
|
|
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
|
|
+#define cublasComputeType_t cudaDataType_t
|
|
|
+
|
|
|
+// XXX: Clang builtins mapping
|
|
|
+#define __vsub4 __vsub4_musa
|
|
|
+#define __vcmpeq4 __vcmpeq4_musa
|
|
|
+#define __vcmpne4 __vcmpne4_musa
|
|
|
#else
|
|
|
#include <cuda_runtime.h>
|
|
|
#include <cuda.h>
|
|
|
@@ -168,9 +316,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
|
|
|
|
|
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
|
|
|
|
|
-#if CUDART_VERSION >= 12000
|
|
|
+#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
|
|
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
|
+#ifndef GGML_USE_MUSA
|
|
|
return cublasGetStatusString(err);
|
|
|
+#else
|
|
|
+ return mublasStatus_to_string(err);
|
|
|
+#endif // GGML_USE_MUSA
|
|
|
}
|
|
|
#else
|
|
|
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
|
@@ -200,7 +352,7 @@ static const char * cu_get_error_str(CUresult err) {
|
|
|
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
|
|
#endif
|
|
|
|
|
|
-#if CUDART_VERSION >= 11100
|
|
|
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
|
|
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
|
|
#else
|
|
|
#define GGML_CUDA_ASSUME(x)
|
|
|
@@ -214,6 +366,42 @@ typedef float dfloat; // dequantize float
|
|
|
typedef float2 dfloat2;
|
|
|
#endif //GGML_CUDA_F16
|
|
|
|
|
|
+#if defined(GGML_USE_MUSA)
|
|
|
+#ifndef __has_builtin
|
|
|
+ #define __has_builtin(x) 0
|
|
|
+#endif
|
|
|
+
|
|
|
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
|
|
+
|
|
|
+static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
|
|
+ return __vsubss4(a, b);
|
|
|
+}
|
|
|
+
|
|
|
+static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
|
|
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
|
|
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
|
|
+ unsigned int c;
|
|
|
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
|
|
+#pragma unroll
|
|
|
+ for (int i = 0; i < 4; ++i) {
|
|
|
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
|
|
+ }
|
|
|
+ return c;
|
|
|
+}
|
|
|
+
|
|
|
+static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
|
|
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
|
|
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
|
|
+ unsigned int c;
|
|
|
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
|
|
+#pragma unroll
|
|
|
+ for (int i = 0; i < 4; ++i) {
|
|
|
+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
|
|
+ }
|
|
|
+ return c;
|
|
|
+}
|
|
|
+#endif // defined(GGML_USE_MUSA)
|
|
|
+
|
|
|
#if defined(GGML_USE_HIPBLAS)
|
|
|
#define __CUDA_ARCH__ 1300
|
|
|
|
|
|
@@ -455,7 +643,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
|
|
|
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
|
|
|
return mask_low | mask_high;
|
|
|
}
|
|
|
-#endif // CUDART_VERSION < 12000
|
|
|
+#endif // CUDART_VERSION < CUDART_HMASK
|
|
|
|
|
|
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
|
|
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|