Sfoglia il codice sorgente

CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (#7921)

* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores

* try CI fix

* try CI fix

* try CI fix

* fix data race

* rever q2_K precision related changes
Johannes Gäßler 1 anno fa
parent
commit
76d66ee0be
6 ha cambiato i file con 356 aggiunte e 203 eliminazioni
  1. 4 2
      ggml-cuda.cu
  2. 1 0
      ggml-cuda/argsort.cu
  3. 5 0
      ggml-cuda/common.cuh
  4. 329 182
      ggml-cuda/mmq.cuh
  5. 1 0
      ggml-cuda/softmax.cu
  6. 16 19
      ggml-cuda/vecdotq.cuh

+ 4 - 2
ggml-cuda.cu

@@ -188,13 +188,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
+        info.devices[id].nsm   = prop.multiProcessorCount;
+        info.devices[id].smpb  = prop.sharedMemPerBlock;
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        info.devices[id].smpbo = prop.sharedMemPerBlock;
         info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
 #else
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        info.devices[id].smpb = prop.sharedMemPerBlock;
-        info.devices[id].nsm  = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {

+ 1 - 0
ggml-cuda/argsort.cu

@@ -73,6 +73,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     const dim3 block_nums(1, nrows, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
 
     if (order == GGML_SORT_ORDER_ASC) {

+ 5 - 0
ggml-cuda/common.cuh

@@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #define FP16_AVAILABLE
 #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
@@ -661,6 +665,7 @@ struct ggml_cuda_device_info {
         int     cc;                 // compute capability
         int     nsm;                // number of streaming multiprocessors
         size_t  smpb;               // max. shared memory per block
+        size_t  smpbo;              // max. shared memory per block (with opt-in)
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;

File diff suppressed because it is too large
+ 329 - 182
ggml-cuda/mmq.cuh


+ 1 - 0
ggml-cuda/softmax.cu

@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:

+ 16 - 19
ggml-cuda/vecdotq.cuh

@@ -265,36 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
 // contiguous u/y values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
-    const half2 & dm2, const float & d8) {
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    int sumi_d = 0;
-    int sumi_m = 0;
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
 
 #pragma unroll
     for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        int sumi_d_sc = 0;
-
-        const int sc = scales[i0 / (QI8_1/2)];
-
-        // fill int with 4x m
-        int m = sc >> 4;
-        m |= m <<  8;
-        m |= m << 16;
+        const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
+        int sumi_d = 0;
+        int sumi_m = 0;
 
+        const int vi0 = v[i0/(QI8_1/2)];
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
-            sumi_m    = __dp4a(m,    u[i], sumi_m); // multiply sum of q8_1 values with m
+            const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
+            sumi_d = __dp4a(vi,         u[i], sumi_d); // SIMD dot product
+            sumi_m = __dp4a(0x01010101, u[i], sumi_m);
         }
 
-        sumi_d += sumi_d_sc * (sc & 0xF);
+        sumf_d += dm2f.x * sumi_d;
+        sumf_m += dm2f.y * sumi_m;
     }
 
-    const float2 dm2f = __half22float2(dm2);
-
-    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+    return d8*(sumf_d - sumf_m);
 #else
     NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -352,8 +347,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
     for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
         int sumi_sc = 0;
 
+#pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+            const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
+            sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product
         }
 
         sumi += sumi_sc * scales[i0 / (QI8_1/2)];

Some files were not shown because too many files changed in this diff