Просмотр исходного кода

cuda : fix HIP and MUSA BF16 (#0)

ggml-ci
Georgi Gerganov 9 месяцев назад
Родитель
Коммит
1a1ab7e7a4

+ 1 - 7
ggml/src/ggml-cuda/convert.cu

@@ -579,13 +579,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
 
     const src_t * x = (const src_t *) vx;
 
-    if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
-        y[i] = __bfloat162float(x[i]);
-    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16> && std::is_same_v<src_t, half>) {
-        y[i] = (float)x[i];
-    } else {
-        y[i] = x[i];
-    }
+    y[i] = float(x[i]);
 }
 
 template <typename src_t, typename dst_t>

+ 1 - 0
ggml/src/ggml-cuda/vendors/hip.h

@@ -20,6 +20,7 @@
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
 #define CUDA_R_16F  HIPBLAS_R_16F
+#define CUDA_R_16BF HIPBLAS_R_16B
 #define CUDA_R_32F  HIPBLAS_R_32F
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
 #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended

+ 1 - 0
ggml/src/ggml-cuda/vendors/musa.h

@@ -15,6 +15,7 @@
 #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
 #define CUDA_R_16F  MUSA_R_16F
+#define CUDA_R_16BF MUSA_R_16BF
 #define CUDA_R_32F  MUSA_R_32F
 #define cublasComputeType_t cudaDataType_t
 #define cublasCreate mublasCreate