Просмотр исходного кода

CUDA: optimize MMQ int8 tensor core performance (#8062)

* CUDA: optimize MMQ int8 tensor core performance

* only a single get_mma_tile_x_k function

* simplify code, make functions constexpr
Johannes Gäßler 1 год назад
Родитель
Сommit
9a590c8226
3 измененных файлов с 441 добавлено и 269 удалено
  1. 2 2
      ggml-cuda/common.cuh
  2. 56 0
      ggml-cuda/mma.cuh
  3. 383 267
      ggml-cuda/mmq.cuh

+ 2 - 2
ggml-cuda/common.cuh

@@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
     static constexpr int qi = QI3_S;
 };
 };
 
 
-static int get_mmq_x_max_host(const int cc) {
+static constexpr int get_mmq_x_max_host(int cc) {
 #ifdef CUDA_USE_TENSOR_CORES
 #ifdef CUDA_USE_TENSOR_CORES
     return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
     return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
 #else
 #else
@@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
 }
 }
 
 
 // Round rows to this value for --split-mode row:
 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc) {
+static constexpr int get_mmq_y_host(int cc) {
     return cc >= CC_VOLTA ? 128 : 64;
     return cc >= CC_VOLTA ? 128 : 64;
 }
 }
 
 

+ 56 - 0
ggml-cuda/mma.cuh

@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret <  K);
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
         return ret;
     }
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 };
 
 
 struct mma_int_A_I16K8 {
 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret <  K);
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
         return ret;
     }
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 };
 
 
 struct mma_int_B_J8K4 {
 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret <  K);
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
         return ret;
     }
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 };
 
 
 struct mma_int_B_J8K8 {
 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret <  K);
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
         return ret;
     }
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 };
 
 
 struct mma_int_C_I16J8 {
 struct mma_int_C_I16J8 {

Разница между файлами не показана из-за своего большого размера
+ 383 - 267
ggml-cuda/mmq.cuh


Некоторые файлы не были показаны из-за большого количества измененных файлов