@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_C_I16J8 {