
CUDA: optimize and refactor MMQ (#8416)

* CUDA: optimize and refactor MMQ

* explicit q8_1 memory layouts, add documentation
Johannes Gäßler 1 year ago
Parent
Commit
808aba3916
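
The q8_1 memory layouts named throughout the hunks below are declared in ggml/src/ggml-cuda/mmq.cuh, whose diff is suppressed on this page. The following is a sketch reconstructed from how quantize.cu writes the d4, ds4 and d2s6 members, not a verbatim copy of the commit:

// Sketch, inferred from the accesses in quantize.cu below (not verbatim):
enum mmq_q8_1_ds_layout {
    MMQ_Q8_1_DS_LAYOUT_D4,   // 4 float scales per block, no partial sums
    MMQ_Q8_1_DS_LAYOUT_DS4,  // 4 half2 (scale, partial sum) pairs per block
    MMQ_Q8_1_DS_LAYOUT_D2S6, // 2 half scales + 6 half partial sums per block
};

// One block covers 4*QK8_1 = 128 values; the union matches the accesses
// y[ib].d4[iqs/32], y[ib].ds4[iqs/32] and y[ib].d2s6[...] seen below.
struct block_q8_1_mmq {
    union {
        float d4[4];   // 1 scale per 32 values
        half2 ds4[4];  // 1 (scale, sum) pair per 32 values
        half  d2s6[8]; // 1 scale per 64 values, 1 sum per 16 values
                       // (sums stored for the first 96 values only)
    };
    int8_t qs[4*QK8_1];
};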

+ 4 - 0
ggml/src/ggml-cuda/mma.cuh

@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
         }
 #endif // defined(INT8_MMA_AVAILABLE)
     }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
 };
 
 struct mma_int_B_J8K4 {
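
The new load_low() fills only the first half of the I16K8 tile's per-thread registers (x[0] and x[1]) by reinterpreting them as an I16K4 fragment, leaving x[2] and x[3] untouched. A hedged usage sketch; the pointer and stride arguments are placeholders, not code from this commit:

mma_int_A_I16K8 A;
A.load_low(tile_ptr, tile_stride); // loads the k < 4 half of the 16x8 tile;
                                   // the caller fills the k >= 4 half separately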

File diff suppressed because it is too large
+ 359 - 343
ggml/src/ggml-cuda/mmq.cuh


+ 82 - 25
ggml/src/ggml-cuda/quantize.cu

@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-template <bool need_sum>
+template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
 
-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
 
     if (ix0 >= kx0_padded) {
         return;
     }
 
+    const float4 * x4 = (const float4 *) x;
+
     const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
-    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;              // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                       // quant index in block
-
-    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
-    float amax = fabsf(xi);
-
-    amax = warp_reduce_max(amax);
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
 
     float sum;
-    if (need_sum) {
-        sum = warp_reduce_sum(xi);
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange the calculated sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
     }
 
-    const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
 
-    y[ib].qs[iqs] = q;
+    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
 
-    if (iqs % QK8_1 != 0) {
         return;
     }
 
-    if (need_sum) {
-        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
     } else {
-        ((float *) y[ib].ds)[iqs/QK8_1] = d;
+        y[ib].d4[iqs/32]  = d;
     }
 }
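
The two __shfl_xor_sync loops above are butterfly reductions over groups of vals_per_scale/4 and vals_per_sum/4 consecutive lanes, since each lane now owns 4 values. A minimal standalone sketch of the pattern for the D2S6 scale case (hypothetical function name, WARP_SIZE = 32 assumed from common.cuh):

// vals_per_scale = 64 at 4 values per lane -> 16 lanes per group.
// Masks 8,4,2,1 flip the low 4 bits of the lane id, so after the loop
// every lane in a 16-lane group holds that group's maximum.
__device__ __forceinline__ float group_max_16(float v) {
#pragma unroll
    for (int mask = 8; mask > 0; mask >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, mask, WARP_SIZE));
    }
    return v;
}

For D4 and DS4 the scale loop instead starts at mask = 4 (8 lanes, 32 values), and the D2S6 sum loop starts at mask = 2 (4 lanes, 16 values).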
 
@@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
 
     GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
     const dim3 num_blocks(block_num_x, kx1, channels);
-    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    if (mmq_need_sum(type_x)) {
-        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else {
-        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
     }
 }
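
Each thread now quantizes 4 values, so one block of CUDA_QUANTIZE_BLOCK_SIZE_MMQ = 128 threads covers 4*128 = 512 values, i.e. exactly 4 block_q8_1_mmq blocks. As a worked example with a hypothetical kx0_padded of 4096: block_num_x = (4096 + 512 - 1) / 512 = 8 blocks along x.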

+ 5 - 1
ggml/src/ggml-cuda/quantize.cuh

@@ -5,7 +5,11 @@
 
 #include <cstdint>
 
-#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE     256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,

+ 49 - 23
ggml/src/ggml-cuda/vecdotq.cuh

@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 }
 
 #define VDR_Q2_K_Q8_1_MMVQ 1
-#define VDR_Q2_K_Q8_1_MMQ  2
+#define VDR_Q2_K_Q8_1_MMQ  4
 
 // contiguous v/x values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
+template <int ns8>
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
 
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
+    float sumf    = 0.0f;
+    float sumf_d8 = 0.0f;
 
 #pragma unroll
-    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
-        int sumi_d = 0;
-        int sumi_m = 0;
+    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+        int sumi_d0 = 0;
+
+        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+        int sumi_d1 = 0;
 
-        const int vi0 = v[i0/(QI8_1/2)];
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
-            sumi_d = ggml_cuda_dp4a(vi,         u[i], sumi_d); // SIMD dot product
-            sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
+            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
+        }
+        sumf_d8 += dm2f0.x * sumi_d0;
+
+#pragma unroll
+        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
         }
+        sumf_d8 += dm2f1.x * sumi_d1;
 
-        sumf_d += dm2f.x * sumi_d;
-        sumf_m += dm2f.y * sumi_m;
+        if (i0/QI8_1 < ns8) {
+            const float2 s8f = __half22float2(s8[i0/QI8_1]);
+            sumf -= dm2f0.y*s8f.x;
+            sumf -= dm2f1.y*s8f.y;
+        } else {
+            int sumi_m0 = 0;
+#pragma unroll
+            for (int i = i0; i < i0 + QI8_1/2; ++i) {
+                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+            }
+            sumf_d8 -= dm2f0.y * sumi_m0;
+
+            int sumi_m1 = 0;
+#pragma unroll
+            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+            }
+            sumf_d8 -= dm2f1.y * sumi_m1;
+        }
     }
 
-    return d8*(sumf_d - sumf_m);
+    return sumf + d8*sumf_d8;
 }
 
 #define VDR_Q3_K_Q8_1_MMVQ 1
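
In vec_dot_q2_K_q8_1_impl_mmq above, the new ns8 template parameter gives the number of precomputed half2 sum pairs available in s8; groups past it (the D2S6 layout stores sums only for the first 96 of 128 values) reconstruct the q8_1 partial sums on the fly via dp4a against 0x01010101. That fallback relies on a simple identity; a minimal sketch, assuming ggml_cuda_dp4a from common.cuh:

// dp4a against 0x01010101 multiplies each packed signed byte of u by 1
// and accumulates, i.e. it returns the sum of the four int8 values in u.
__device__ __forceinline__ int byte_sum(const int u) {
    return ggml_cuda_dp4a(0x01010101, u, 0);
}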
@@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
     return d3 * sumf;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
     const float & d3, const float & d8) {
@@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
-            sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
+            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
         }
 
         sumi += sumi_sc * scales[i0 / (QI8_1/2)];
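
The deleted __vsubss4 line unpacked two 4 bit q3_K values per byte and subtracted the offset 4 in-register on every dot product; after this commit v already arrives unpacked (hence the updated "contiguous v/x + u/y values" comments). For reference, a standalone sketch of what the removed line computed; the function name is hypothetical:

// Old in-register unpacking, now done once while loading tiles:
// select nibble i of each byte (low for even i, high for odd i)
// and subtract 4 from each of the four resulting int8 lanes.
__device__ __forceinline__ int unpack_q3(const int v_packed, const int i) {
    return __vsubss4((v_packed >> (4*(i % 2))) & 0x0F0F0F0F, 0x04040404);
}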
@@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
     return d*sumf;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
     const float & d6, const float * __restrict__ d8) {
 
     float sumf_d = 0.0f;
 
+    const int      sc_packed = get_int_b4(sc, 0);
+    const int8_t * sc_reg    = (const int8_t *) &sc_packed;
+
 #pragma unroll
     for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
         int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
             sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
         }
 
-        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
     }
 
     return d6 * sumf_d;
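
Packing the q6_K scales into one 32 bit load lets the compiler keep all four int8_t scales in a register instead of issuing four separate byte loads. get_int_b4 is defined elsewhere in the tree; assuming it is a plain 4-byte-aligned int load, the pattern is equivalent to this sketch:

// Read 4 packed int8_t scales with a single aligned int load;
// sc_reg[0..3] then alias the four scales inside the register copy.
const int      sc_packed = *(const int *) sc;
const int8_t * sc_reg    = (const int8_t *) &sc_packed;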

Some files were not shown because too many files changed in this diff