@@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
 
 namespace ggml_cuda_mma {
 
+    // Some architectures such as Volta or CDNA3 perform multiple matrix multiplications per warp in parallel;
+    // effectively the warp is split into subgroups of threads that each perform a single mma instruction.
+    // In those cases the data can be split in different ways across the warp.
+    enum data_layout {
+        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+        // For the A/C matrices this means I major == row major, J major == column major.
+        // For the B matrix this means I major == column major, J major == row major.
+        // MIRRORED == each data value is held exactly once per thread subgroup.
+        DATA_LAYOUT_I_MAJOR          =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+    };
+    // Implemented mma combinations are:
+    // - (I_MAJOR, I_MAJOR)          -> I_MAJOR
+    // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+    // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
+    struct tile {};
+
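// Editor's illustrative sketch (not part of the patch): the per-thread element counts implied by the
// layouts above, written out for the 8x4 half2 shape used by the mirrored Volta operands later in this
// diff. The namespace and constant names below are hypothetical; WARP_SIZE is 32 on this hardware.
namespace data_layout_example {
    constexpr int kWarpSize = 32;
    constexpr int I = 8, J = 4;

    // I_MAJOR: every value is stored exactly once per warp -> 1 element per thread for an 8x4 tile.
    constexpr int ne_i_major  = I * J /  kWarpSize;
    // *_MIRRORED: the warp acts as 4 subgroups of 8 threads and each subgroup holds the full tile,
    // so every value is replicated 4 times across the warp -> 4 elements per thread.
    constexpr int ne_mirrored = I * J / (kWarpSize / 4);

    static_assert(ne_i_major == 1 && ne_mirrored == 4, "8x4 half2 tile: 1 vs. 4 values per thread");
}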
     template <int I_, int J_, typename T>
-    struct tile {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
@@ -131,9 +152,9 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 32 && J == 8) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
-                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
 #else
-                return (l & 2) | (threadIdx.x & ~2);
+                return (l & 2) + (threadIdx.x & ~2);
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
             } else {
                 NO_DEVICE_CODE;
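// Editor's illustrative sketch (not part of the patch): the bitwise-OR -> '+' rewrites in these index
// helpers are value-preserving because the combined terms never have set bits in common, so a | b == a + b.
// The hypothetical host-side helper below spells this out for the two Volta get_i() expressions above.
#include <cassert>

static void check_or_equals_add_for_volta_get_i() {
    for (int lane = 0; lane < 32; ++lane) {      // stands in for threadIdx.x
        for (int l = 0; l < 8; ++l) {
            const int a = ((lane % 16) / 4) * 8; // occupies bits 3..4
            const int b = (lane / 16) * 4;       // occupies bit  2
            const int c = l & 2;                 // occupies bit  1
            const int d = lane % 2;              // occupies bit  0
            assert((a | b | c | d) == (a + b + c + d));
            assert(((l & 2) | (lane & ~2)) == ((l & 2) + (lane & ~2)));
        }
    }
}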
@@ -143,7 +164,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 32 && J == 8) {
-                return (threadIdx.x & 2) | (l & (4 + 1));
+                return (threadIdx.x & 2) + (l & (4 + 1));
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -196,9 +217,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 8) | (threadIdx.x / 4);
+                return ((l / 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 32 && J == 8) {
                 return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
@@ -211,11 +232,11 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((threadIdx.x % 4) * 2) | (l % 2);
+                return ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 32 && J == 8) {
                 return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
@@ -227,26 +248,24 @@ namespace ggml_cuda_mma {
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, half2> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I == 8 && J == 8) return true;
-            if (I == 32 && J == 8) return true;
+            if (I == 32 && J == 4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && J == 8) {
-                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
-            } else if constexpr (I == 32 && J == 8) {
+            if constexpr (I == 32 && J == 4) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
-                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
 #else
                 return threadIdx.x;
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
@@ -257,7 +276,7 @@ namespace ggml_cuda_mma {
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr ((I == 8 || I == 32) && J == 8) {
+            if constexpr (I == 32 && J == 4) {
                 return l;
             } else {
                 NO_DEVICE_CODE;
@@ -307,11 +326,11 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return (l * 8) | (threadIdx.x / 4);
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l % 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 32 && J == 8) {
-                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -320,13 +339,13 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 4) | (threadIdx.x % 4);
+                return ((l / 2) * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 32 && J == 8) {
-                return ((l & 2) * 2) | (threadIdx.x % 4);
+                return ((l & 2) * 2) + (threadIdx.x % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -336,14 +355,15 @@ namespace ggml_cuda_mma {
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, nv_bfloat162> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+        static constexpr int ne = I * J / WARP_SIZE;
 
-#if defined(AMD_WMMA_AVAILABLE)
-        static constexpr int ne = I * J / 32;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
+#if defined(AMD_WMMA_AVAILABLE)
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 8) return true;
             return false;
@@ -367,9 +387,6 @@ namespace ggml_cuda_mma {
             }
         }
 #else
-        static constexpr int ne = I * J / WARP_SIZE;
-        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
-
         static constexpr __device__ bool supported() {
             if (I == 8 && J == 8) return true;
             if (I == 16 && J == 4) return true;
@@ -381,9 +398,9 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return (l * 8) | (threadIdx.x / 4);
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -392,11 +409,11 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 4) | (threadIdx.x % 4);
+                return ((l / 2) * 4) + (threadIdx.x % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -405,6 +422,73 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int /*l*/) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((l / 2) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 2) + (l % 2);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
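// Editor's illustrative sketch (not part of the patch): a hypothetical host-side model of the 8x4
// DATA_LAYOUT_I_MAJOR_MIRRORED mapping added above. It mirrors get_i()/get_j() with threadIdx.x replaced
// by an explicit lane index and checks that every (i, j) element is held by exactly 4 lanes, i.e. once
// per 8-lane subgroup, which is what ne = I*J/(WARP_SIZE/4) encodes.
#include <cassert>

static void check_i_major_mirrored_8x4_replication() {
    int count[8][4] = {};
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 4; ++l) {                    // ne == 4 for this layout
            const int i = (lane / 16) * 4 + lane % 4;    // tile<8, 4, half2, I_MAJOR_MIRRORED>::get_i()
            const int j = l;                             // tile<8, 4, half2, I_MAJOR_MIRRORED>::get_j()
            ++count[i][j];
        }
    }
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 4; ++j) {
            assert(count[i][j] == 4); // each value replicated once per subgroup, 4 subgroups per warp
        }
    }
}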
+#if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -422,9 +506,26 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#else // Volta
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
+            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
+
+            // On Volta the FP16 and FP32 tiles have different memory layouts;
+            // for the conversion, threads with an offset of 2 need to exchange half of their values:
+            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
+                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
+        }
+        return ret;
+    }
+#endif // defined(TURING_MMA_AVAILABLE)
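// Editor's illustrative sketch (not part of the patch): the lane-exchange primitive used by the Volta
// get_half2() above, reduced to a hypothetical standalone kernel. __shfl_xor_sync(mask, v, 2) pairs every
// lane with the lane whose index differs in bit 1 (lane ^ 2), so lanes {0,2}, {1,3}, {4,6}, ... swap the
// value they pass in; get_half2() relies on exactly this pairing to hand one of its two freshly packed
// half2 values to the partner lane.
__global__ void shfl_xor_2_demo(int * out) {
    const int lane         = threadIdx.x % 32;
    const int from_partner = __shfl_xor_sync(0xFFFFFFFF, lane, 2); // value contributed by lane ^ 2
    out[lane] = from_partner;                                      // out[lane] == lane ^ 2
}
// Usage (hypothetical): shfl_xor_2_demo<<<1, 32>>>(d_out);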
 
-    template <int I, int J, typename T>
-    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+    template <int I, int J, typename T, data_layout dl>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
         if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
 #pragma unroll
@@ -511,18 +612,6 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-        GGML_UNUSED_VARS(t, xs0, stride);
-        NO_DEVICE_CODE;
-#else
-        load_generic(t, xs0, stride);
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-#endif // TURING_MMA_AVAILABLE
-    }
-
-    template <typename T>
-    static __device__ __forceinline__ void load_ldmatrix(
-            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #if 1
         // TODO: more generic handling
@@ -533,9 +622,31 @@ namespace ggml_cuda_mma {
         load_generic(t, xs0, stride);
 #endif // 1
 #else
-        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
-        load_ldmatrix(t16[0], xs0 + 0*stride, stride);
-        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l0 = 0; l0 < t.ne; l0 += 2) {
+            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
+        }
+    }
+
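// Editor's illustrative sketch (not part of the patch): a hypothetical host-side model of what the
// 16-byte copy in the DATA_LAYOUT_I_MAJOR_MIRRORED load above reads per lane. get_i() ignores l for that
// layout, so each lane copies one full 4-wide half2 row starting at xs0 + get_i(0)*stride; lanes that map
// to the same row (e.g. lanes 0, 4, 8 and 12) load identical data, which is the mirroring across subgroups.
#include <cstdio>

static void print_i_major_mirrored_load_rows(const int stride /* in half2 elements */) {
    for (int lane = 0; lane < 32; ++lane) {          // stands in for threadIdx.x
        const int row = (lane / 16) * 4 + lane % 4;  // tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(0)
        std::printf("lane %2d reads half2 elements [%d, %d) of xs0\n",
                    lane, row * stride, row * stride + 4);
    }
}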
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+#else
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     }
 
@@ -860,14 +971,14 @@ namespace ggml_cuda_mma {
     template <typename T1, typename T2, int J, int K>
     static __device__ __forceinline__ void mma(
             tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
-        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
-        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        tile      <16, J, T1> * D16 = reinterpret_cast<      tile<16, J, T1> *>(&D);
+        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
         mma(D16[0], A16[0], B);
         mma(D16[1], A16[1], B);
     }
 
     static __device__ __forceinline__ void mma(
-            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -880,20 +991,30 @@ namespace ggml_cuda_mma {
             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
-        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
-            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
-            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
-            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
-        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
-            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
-            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
-            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
 #else
-        tile      <16, 8, float> * D16 = reinterpret_cast<tile      <16, 8, float> *>(&D);
-        const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
-        mma(D16[0], A16[0], B);
-        mma(D16[1], A16[1], B);
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
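// Editor's illustrative sketch (not part of the patch): how the Volta-only pieces introduced in this diff
// fit together. The tile shapes, load_ldmatrix() and mma() overloads are the ones added above; the caller
// itself (its name, the dst/As/Bs buffers and their strides) is hypothetical.
__device__ void example_volta_mma_32x8(
        float * __restrict__ dst,                  // 32x8 row-major output block
        const half2 * __restrict__ As,             // A operand: 32 rows of 4 half2, I-major
        const half2 * __restrict__ Bs,             // B operand: 8 rows of 4 half2, loaded mirrored
        const int stride_A, const int stride_B) {  // strides in half2 elements
    using namespace ggml_cuda_mma;

    tile<32, 8, float>                               D; // accumulator
    tile<32, 4, half2>                               A;
    tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> B;

    load_ldmatrix(A, As, stride_A);
    load_ldmatrix(B, Bs, stride_B);
    mma(D, A, B); // dispatches to the m8n8k4 pipeline above when __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA

    // Each thread writes back its fragment of the 32x8 result.
#pragma unroll
    for (int l = 0; l < D.ne; ++l) {
        dst[D.get_i(l)*8 + D.get_j(l)] = D.x[l];
    }
}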
 
     static __device__ __forceinline__ void mma(