@@ -2,6 +2,7 @@
 
 #include "mma.cuh"
 #include "common.cuh"
+#include "convert.cuh"
 
 using namespace ggml_cuda_mma;
 
@@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
+    constexpr bool supported = I_16_supported || I_32_supported;
 
     constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
 
     typedef tile<I_preferred, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
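On the AMD WMMA path, a float input means tf32 math, which RDNA WMMA does not provide, so the 8-wide tile_B/tile_C shapes are intentionally invalid: supported() evaluates to false at compile time and the kernel exits through NO_DEVICE_CODE. A minimal host-side sketch of that selection logic, assuming a placeholder shape rule in place of the real tile<>::supported():

// Host-side sketch of the tf32 gating above; compiles as standalone C++17.
// wmma_shape_ok() is a hypothetical stand-in for tile<>::supported(), and
// half_stub stands in for the half2/nv_bfloat162 element types.
#include <cstdio>
#include <type_traits>

struct half_stub {};

constexpr bool wmma_shape_ok(int i) { return i == 16; } // placeholder rule

template <typename T>
constexpr bool wmma_supported() {
    // Same 8-vs-16 shape selection as the diff: float (tf32) gets dummy shapes.
    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
    return wmma_shape_ok(tile_B_I) && wmma_shape_ok(tile_C_J);
}

int main() {
    printf("float (tf32) supported: %d\n", wmma_supported<float>());     // prints 0
    printf("half-like    supported: %d\n", wmma_supported<half_stub>()); // prints 1
}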
@@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
 
             if constexpr (!has_ids) {
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             } else {
                 const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                 float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             }
         }
     } else {
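Both stores now route through ggml_cuda_cast<T> (hence the new convert.cuh include) instead of aggregate initialization from two floats, keeping the float2-to-T conversion in one helper that can pick the right intrinsic per element type. A hedged sketch of what such a helper can look like; these overloads are illustrative assumptions, not the actual convert.cuh implementation:

// Illustrative only: a float2 -> T conversion helper in the spirit of
// ggml_cuda_cast; the real convert.cuh code is not shown in this diff.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T>
static __device__ __forceinline__ T cast_sketch(const float2 v);

template <>
__device__ __forceinline__ float2 cast_sketch<float2>(const float2 v) {
    return v; // already the target type
}

template <>
__device__ __forceinline__ half2 cast_sketch<half2>(const float2 v) {
    return __float22half2_rn(v); // pairwise round-to-nearest-even
}

template <>
__device__ __forceinline__ nv_bfloat162 cast_sketch<nv_bfloat162>(const float2 v) {
    return __float22bfloat162_rn(v); // pairwise float -> bf16
}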
@@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 //This kernel is for larger batch sizes of mul_mat_id
@@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+    constexpr bool supported = I_16_supported || I_32_supported;
 
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
 
     typedef tile<I_preferred, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
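mul_mat_f_ids receives the same tile-selection restructuring as mul_mat_f (the hunk also fixes the "butr" typo in the Turing comment). Since the TODO anticipates a simpler, consistent treatment once AMD MFMA support lands, the duplicated block could later collapse into a shared traits struct; the sketch below is one possible shape for that follow-up, under the hypothetical name mmf_tile_config, and is not part of this change:

// Hypothetical consolidation sketch; assumes ggml_cuda_mma::tile from mma.cuh.
template <typename T>
struct mmf_tile_config {
#if defined(AMD_WMMA_AVAILABLE)
    static constexpr int B_I = std::is_same_v<T, float> ? 8 : 16;
    static constexpr int C_J = std::is_same_v<T, float> ? 8 : 16;
    typedef tile<16,  8,   T>     tile_A;
    typedef tile<B_I, 8,   T>     tile_B;
    typedef tile<16,  C_J, float> tile_C;
    static constexpr bool supported =
        tile_A::supported() && tile_B::supported() && tile_C::supported();
#else
    static constexpr bool I_16 = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    static constexpr bool I_32 = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
    static constexpr int  I_preferred = I_16 ? 16 : 32; // 16 is ~1% faster on Turing
    typedef tile<I_preferred, 8, T>     tile_A;
    typedef tile<8,           8, T>     tile_B;
    typedef tile<I_preferred, 8, float> tile_C;
    static constexpr bool supported = I_16 || I_32;
#endif
};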
@@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
 #pragma unroll
             for (int j0 = 0; j0 < tile_B::I; ++j0) {
                 const float2 tmp = vals_buf[curr_buf][j0];
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             }
 
             if (itB + 1 < ntB) {
@@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template<typename T, int cols_per_block, int nwarps>
@@ -554,7 +585,8 @@ void mul_mat_f_cuda(
         cudaStream_t stream, const mmf_ids_data * ids_data) {
     typedef tile<16, 8, T> tile_A_16;
     typedef tile<32, 8, T> tile_A_32;
-    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, T> tile_B_16;
+    typedef tile< 8, 8, T> tile_B_8;
 
     GGML_ASSERT(ncols_x % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
@@ -581,7 +613,8 @@
 
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
     const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
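The combine buffer's padding granularity now depends on the target: tile_B_16::I (16) when amd_wmma_available(cc), tile_B_8::I (8) otherwise, matching the 16-row B tiles the WMMA path uses. A worked example of the size formula above; cols_per_block = 8, nwarps_best = 4 and rows_per_block = 32 are assumed values for illustration, not taken from this diff:

// Worked example of nbytes_shared_combine; compiles as standalone C++.
#include <cstdio>
#include <initializer_list>

constexpr int pad_to(int x, int n) { return (x + n - 1) / n * n; } // mirrors GGML_PAD

int main() {
    const int cols_per_block = 8, nwarps_best = 4, rows_per_block = 32;
    for (int pad : {8, 16}) { // 8: non-WMMA (tile_B_8::I), 16: AMD WMMA (tile_B_16::I)
        const int nbytes = pad_to(cols_per_block, pad) * (nwarps_best*rows_per_block + 4) * 4;
        printf("pad=%2d -> nbytes_shared_combine = %d bytes\n", pad, nbytes);
    }
    // pad= 8 ->  8 * 132 * 4 = 4224 bytes
    // pad=16 -> 16 * 132 * 4 = 8448 bytes
}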