|
|
@@ -6,15 +6,116 @@
|
|
|
#include <atomic>
|
|
|
#include <assert.h>
|
|
|
|
|
|
+#if defined(GGML_USE_HIPBLAS) // HIP build: pull in HIP/hipBLAS headers and alias the CUDA/cuBLAS names to HIP equivalents
|
|
|
+#include <hip/hip_runtime.h>
|
|
|
+#include <hipblas/hipblas.h>
|
|
|
+#include <hip/hip_fp16.h>
|
|
|
+#ifdef __HIP_PLATFORM_AMD__
|
|
|
+// for rocblas_initialize()
|
|
|
+#include "rocblas/rocblas.h"
|
|
|
+#endif // __HIP_PLATFORM_AMD__
|
|
|
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
|
|
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F // same value as CUBLAS_COMPUTE_32F above: both names collapse to one type here
|
|
|
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
|
|
+#define CUBLAS_OP_N HIPBLAS_OP_N
|
|
|
+#define CUBLAS_OP_T HIPBLAS_OP_T
|
|
|
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
|
|
+#define CUBLAS_TF32_TENSOR_OP_MATH 0 // placeholder value; the cublasSetMathMode macro below discards its arguments
|
|
|
+#define CUDA_R_16F HIPBLAS_R_16F
|
|
|
+#define CUDA_R_32F HIPBLAS_R_32F
|
|
|
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) // the mask argument is dropped
|
|
|
+#define cublasCreate hipblasCreate
|
|
|
+#define cublasGemmEx hipblasGemmEx
|
|
|
+#define cublasHandle_t hipblasHandle_t
|
|
|
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS // no call is made: expands directly to a success status
|
|
|
+#define cublasSetStream hipblasSetStream
|
|
|
+#define cublasSgemm hipblasSgemm
|
|
|
+#define cublasStatus_t hipblasStatus_t
|
|
|
+#define cudaDeviceProp hipDeviceProp_t
|
|
|
+#define cudaDeviceSynchronize hipDeviceSynchronize
|
|
|
+#define cudaError_t hipError_t
|
|
|
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
|
|
+#define cudaEventDisableTiming hipEventDisableTiming
|
|
|
+#define cudaEventRecord hipEventRecord
|
|
|
+#define cudaEvent_t hipEvent_t
|
|
|
+#define cudaEventDestroy hipEventDestroy
|
|
|
+#define cudaFree hipFree
|
|
|
+#define cudaFreeHost hipHostFree // note the different word order in the HIP name
|
|
|
+#define cudaGetDevice hipGetDevice
|
|
|
+#define cudaGetDeviceCount hipGetDeviceCount
|
|
|
+#define cudaGetDeviceProperties hipGetDeviceProperties
|
|
|
+#define cudaGetErrorString hipGetErrorString
|
|
|
+#define cudaGetLastError hipGetLastError
|
|
|
+#define cudaMalloc hipMalloc
|
|
|
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) // supplies the extra flags argument
|
|
|
+#define cudaMemcpy hipMemcpy
|
|
|
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
|
|
+#define cudaMemcpyAsync hipMemcpyAsync
|
|
|
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
|
|
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
|
|
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
|
|
+#define cudaMemcpyKind hipMemcpyKind
|
|
|
+#define cudaMemset hipMemset
|
|
|
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
|
|
+#define cudaSetDevice hipSetDevice
|
|
|
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
|
|
+#define cudaStreamNonBlocking hipStreamNonBlocking
|
|
|
+#define cudaStreamSynchronize hipStreamSynchronize
|
|
|
+#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0) // only the two-argument form is accepted; third argument fixed to 0
|
|
|
+#define cudaStream_t hipStream_t
|
|
|
+#define cudaSuccess hipSuccess
|
|
|
+#else // not a HIP build: use the native CUDA / cuBLAS headers
|
|
|
#include <cuda_runtime.h>
|
|
|
#include <cublas_v2.h>
|
|
|
#include <cuda_fp16.h>
|
|
|
+#endif // defined(GGML_USE_HIPBLAS)
|
|
|
|
|
|
#include "ggml-cuda.h"
|
|
|
#include "ggml.h"
|
|
|
|
|
|
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
|
|
+#ifndef CC_TURING // only define the default when no earlier definition exists
|
|
|
#define CC_TURING 700 // NOTE(review): presumably a compute-capability threshold in the same encoding as MIN_CC_DP4A (610) - confirm that 700 is the intended value
|
|
|
+#endif // CC_TURING
|
|
|
+
|
|
|
+#if defined(GGML_USE_HIPBLAS) // HIP-only section: supplies __vsubss4 and __dp4a, closed by the #endif after __dp4a
|
|
|
+#define __CUDA_ARCH__ 1300 // NOTE(review): sets the CUDA arch macro for HIP builds, presumably so code gated on __CUDA_ARCH__ (e.g. >= MIN_CC_DP4A) is enabled - confirm
|
|
|
+
|
|
|
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); // vector of four int8_t lanes; used below to view an int as four signed bytes
|
|
|
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) { // lane-wise saturating a - b, treating each int as four int8_t lanes; returns the four results packed in an int
|
|
|
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); // view the bits of a as four int8_t lanes
|
|
|
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b); // view the bits of b as four int8_t lanes
|
|
|
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); // per-lane saturating subtraction
|
|
|
+ return reinterpret_cast<const int&>(c); // reinterpret the four result lanes as one int
|
|
|
+}
|
|
|
+
|
|
|
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { // returns c plus the sum of the four int8 lane products of a and b (as spelled out in the #else fallback); other branches are presumably equivalent per-target forms
|
|
|
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) // targets handled with the sdot4 builtin
|
|
|
+ c = __builtin_amdgcn_sdot4(a, b, c, false); // NOTE(review): last argument is presumably the clamp/saturate flag - confirm against the clang AMDGPU builtin docs
|
|
|
+#elif defined(__gfx1100__) // target handled with the sudot4 builtin
|
|
|
+ c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); // NOTE(review): the two 'true' flags presumably mark a and b as signed; last argument presumably clamp - confirm
|
|
|
+#elif defined(__gfx1010__) || defined(__gfx900__) // inline asm: multiply sign-extended byte pairs (BYTE_0..BYTE_3 selectors) and add the products into %0 two at a time
|
|
|
+ int tmp1; // scratch register for the first product of each pair (%1)
|
|
|
+ int tmp2; // scratch register for the second product of each pair (%2)
|
|
|
+ asm("\n \
|
|
|
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
|
|
|
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
|
|
|
+ v_add3_u32 %0, %1, %2, %0 \n \
|
|
|
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
|
|
|
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
|
|
|
+ v_add3_u32 %0, %1, %2, %0 \n \
|
|
|
+ "
|
|
|
+ : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) // outputs: %0 = c (read and written), %1/%2 = early-clobber temporaries
|
|
|
+ : "v"(a), "v"(b) // inputs: %3 = a, %4 = b
|
|
|
+ );
|
|
|
+#else // any other target: plain C++ fallback
|
|
|
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); // view a as four int8_t lanes
|
|
|
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b); // view b as four int8_t lanes
|
|
|
+ c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; // accumulate the four lane products into c
|
|
|
+#endif
|
|
|
+ return c; // accumulator after adding the dot product
|
|
|
+}
|
|
|
+#endif
|
|
|
|
|
|
#if defined(_MSC_VER)
|
|
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
|
@@ -424,8 +525,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
|
|
|
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
|
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
|
|
|
|
|
- const dfloat d = x[ib].dm.x;
|
|
|
- const dfloat m = x[ib].dm.y;
|
|
|
+ const dfloat d = __low2half(x[ib].dm);
|
|
|
+ const dfloat m = __high2half(x[ib].dm);
|
|
|
|
|
|
const int vui = x[ib].qs[iqs];
|
|
|
|
|
|
@@ -467,8 +568,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
|
|
|
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
|
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
|
|
|
|
|
- const dfloat d = x[ib].dm.x;
|
|
|
- const dfloat m = x[ib].dm.y;
|
|
|
+ const dfloat d = __low2half(x[ib].dm);
|
|
|
+ const dfloat m = __high2half(x[ib].dm);
|
|
|
|
|
|
uint32_t qh;
|
|
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
|
|
@@ -520,8 +621,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
|
|
|
const uint8_t q = x[i].qs[32*n + l];
|
|
|
float * y = yy + i*QK_K + 128*n;
|
|
|
|
|
|
- float dall = x[i].dm.x;
|
|
|
- float dmin = x[i].dm.y;
|
|
|
+ float dall = __low2half(x[i].dm);
|
|
|
+ float dmin = __high2half(x[i].dm);
|
|
|
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
|
|
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
|
|
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
|
|
@@ -531,8 +632,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
|
|
|
const int il = tid%16; // 0...15
|
|
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
|
|
float * y = yy + i*QK_K + 16*is + il;
|
|
|
- float dall = x[i].dm.x;
|
|
|
- float dmin = x[i].dm.y;
|
|
|
+ float dall = __low2half(x[i].dm);
|
|
|
+ float dmin = __high2half(x[i].dm);
|
|
|
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
|
|
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
|
|
#endif
|
|
|
@@ -618,8 +719,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
|
|
|
|
|
|
float * y = yy + i*QK_K + 64*il + n*ir;
|
|
|
|
|
|
- const float dall = x[i].dm.x;
|
|
|
- const float dmin = x[i].dm.y;
|
|
|
+ const float dall = __low2half(x[i].dm);
|
|
|
+ const float dmin = __high2half(x[i].dm);
|
|
|
|
|
|
const uint8_t * q = x[i].qs + 32*il + n*ir;
|
|
|
|
|
|
@@ -657,8 +758,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
|
|
|
|
|
|
float * y = yy + i*QK_K + 64*il + 2*ir;
|
|
|
|
|
|
- const float dall = x[i].dm.x;
|
|
|
- const float dmin = x[i].dm.y;
|
|
|
+ const float dall = __low2half(x[i].dm);
|
|
|
+ const float dmin = __high2half(x[i].dm);
|
|
|
|
|
|
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
|
|
|
const uint8_t * qh = x[i].qh + 2*ir;
|
|
|
@@ -770,8 +871,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
|
|
const float * y = yy + i * QK_K + y_offset;
|
|
|
const uint8_t * q = x[i].qs + q_offset;
|
|
|
|
|
|
- const float dall = x[i].dm.x;
|
|
|
- const float dmin = x[i].dm.y;
|
|
|
+ const float dall = __low2half(x[i].dm);
|
|
|
+ const float dmin = __high2half(x[i].dm);
|
|
|
|
|
|
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
|
|
|
aux[0] = a[0] & 0x0f0f0f0f;
|
|
|
@@ -991,8 +1092,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
|
|
|
const float * y1 = yy + i*QK_K + y_offset;
|
|
|
const float * y2 = y1 + 128;
|
|
|
|
|
|
- const float dall = x[i].dm.x;
|
|
|
- const float dmin = x[i].dm.y;
|
|
|
+ const float dall = __low2half(x[i].dm);
|
|
|
+ const float dmin = __high2half(x[i].dm);
|
|
|
|
|
|
const uint16_t * a = (const uint16_t *)x[i].scales;
|
|
|
aux[0] = a[im+0] & kmask1;
|
|
|
@@ -1124,8 +1225,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
|
|
|
const float * y1 = yy + i*QK_K + y_offset;
|
|
|
const float * y2 = y1 + 128;
|
|
|
|
|
|
- const float dall = x[i].dm.x;
|
|
|
- const float dmin = x[i].dm.y;
|
|
|
+ const float dall = __low2half(x[i].dm);
|
|
|
+ const float dmin = __high2half(x[i].dm);
|
|
|
|
|
|
const uint16_t * a = (const uint16_t *)x[i].scales;
|
|
|
aux[0] = a[im+0] & kmask1;
|
|
|
@@ -1348,8 +1449,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- y[ib].ds.x = d;
|
|
|
- y[ib].ds.y = sum;
|
|
|
+ reinterpret_cast<half&>(y[ib].ds.x) = d;
|
|
|
+ reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
|
|
}
|
|
|
|
|
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
|
|
@@ -2346,7 +2447,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
|
|
|
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
|
|
}
|
|
|
|
|
|
- return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
|
|
|
+ return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
|
|
|
}
|
|
|
|
|
|
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
|
|
@@ -2432,7 +2533,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < QR2_K; ++ i) {
|
|
|
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
|
|
|
- d8[i] = bq8_1[bq8_offset + i].ds.x;
|
|
|
+ d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
|
|
|
}
|
|
|
|
|
|
return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
|
|
|
@@ -2551,7 +2652,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < QR3_K; ++i) {
|
|
|
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
|
|
|
- d8[i] = bq8_1[bq8_offset + i].ds.x;
|
|
|
+ d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
|
|
|
}
|
|
|
|
|
|
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
|
|
|
@@ -2720,7 +2821,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
|
|
|
|
|
for (int i = 0; i < QR4_K; ++i) {
|
|
|
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
|
|
- d8[i] = bq8i->ds.x;
|
|
|
+ d8[i] = __low2half(bq8i->ds);
|
|
|
|
|
|
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
|
|
|
u[2*i+0] = q8[0];
|
|
|
@@ -2747,8 +2848,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
|
|
const float dall = bq4_K->d[0];
|
|
|
const float dmin = bq4_K->d[1];
|
|
|
|
|
|
- const float d8_1 = bq8_1[0].ds.x;
|
|
|
- const float d8_2 = bq8_1[1].ds.x;
|
|
|
+ const float d8_1 = __low2float(bq8_1[0].ds);
|
|
|
+ const float d8_2 = __low2float(bq8_1[1].ds);
|
|
|
|
|
|
const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
|
|
|
const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
|
|
|
@@ -2901,7 +3002,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < QR5_K; ++i) {
|
|
|
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
|
|
- d8[i] = bq8i->ds.x;
|
|
|
+ d8[i] = __low2float(bq8i->ds);
|
|
|
|
|
|
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
|
|
|
u[2*i+0] = q8[0];
|
|
|
@@ -2919,8 +3020,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
|
|
|
|
|
const float d = bq5_K->d;
|
|
|
|
|
|
- const float d8_1 = bq8_1[0].ds.x;
|
|
|
- const float d8_2 = bq8_1[1].ds.x;
|
|
|
+ const float d8_1 = __low2half(bq8_1[0].ds);
|
|
|
+ const float d8_2 = __low2half(bq8_1[1].ds);
|
|
|
|
|
|
const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
|
|
|
const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
|
|
|
@@ -3075,7 +3176,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < QR6_K; ++i) {
|
|
|
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
|
|
|
- d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
|
|
|
+ d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds);
|
|
|
}
|
|
|
|
|
|
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
|
|
|
@@ -3243,7 +3344,7 @@ static __device__ __forceinline__ void mul_mat_q(
|
|
|
*dsi_dst = *dsi_src;
|
|
|
} else {
|
|
|
float * dfi_dst = (float *) dsi_dst;
|
|
|
- *dfi_dst = (*dsi_src).x;
|
|
|
+ *dfi_dst = __low2half(*dsi_src);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -4944,10 +5045,18 @@ void ggml_init_cublas() {
|
|
|
static bool initialized = false;
|
|
|
|
|
|
if (!initialized) {
|
|
|
+
|
|
|
+#ifdef __HIP_PLATFORM_AMD__
|
|
|
+ // Workaround for a rocBLAS bug when using multiple graphics cards:
|
|
|
+ // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
|
|
|
+ rocblas_initialize();
|
|
|
+ CUDA_CHECK(cudaDeviceSynchronize());
|
|
|
+#endif
|
|
|
+
|
|
|
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
|
|
|
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
|
|
|
int64_t total_vram = 0;
|
|
|
- fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
|
|
|
+ fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
|
|
|
for (int id = 0; id < g_device_count; ++id) {
|
|
|
cudaDeviceProp prop;
|
|
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|