|
|
@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
|
|
|
#endif // AMD_MFMA_AVAILABLE
|
|
|
|
|
|
#if defined(GGML_USE_HIP)
|
|
|
-static int mmq_get_nwarps_host(const int cc) {
|
|
|
- return amd_mfma_available(cc) ? 8 : 4;
|
|
|
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
|
|
|
+ return amd_mfma_available(cc) ? 8 : 256/warp_size;
|
|
|
}
|
|
|
#else
|
|
|
-static int mmq_get_nwarps_host(const int /*cc*/) {
|
|
|
- return 8;
|
|
|
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
|
|
|
+ return 256/warp_size;
|
|
|
}
|
|
|
#endif // (GGML_USE_HIP)
|
|
|
|
|
|
static constexpr __device__ int mmq_get_nwarps_device() {
|
|
|
-#if defined(GGML_USE_HIP)
|
|
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
|
return 8;
|
|
|
#else
|
|
|
- return 4;
|
|
|
+ return 256/ggml_cuda_get_physical_warp_size();
|
|
|
#endif // AMD_MFMA_AVAILABLE
|
|
|
-#else
|
|
|
- return 8;
|
|
|
-#endif // defined(GGML_USE_HIP)
|
|
|
}
|
|
|
|
|
|
// ------------------------------------------------------------
|
|
|
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
|
const int cc = ggml_cuda_info().devices[id].cc;
|
|
|
const int nsm = ggml_cuda_info().devices[id].nsm;
|
|
|
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
|
|
- const int nwarps = mmq_get_nwarps_host(cc);
|
|
|
+ const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
|
|
const int mmq_y = get_mmq_y_host(cc);
|
|
|
|
|
|
const dim3 block_dims(warp_size, nwarps, 1);
|
|
|
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
|
|
const int cc = ggml_cuda_info().devices[id].cc;
|
|
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
|
|
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
|
|
- const int nwarps = mmq_get_nwarps_host(cc);
|
|
|
+ const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
|
|
|
|
|
const int mmq_x_max = get_mmq_x_max_host(cc);
|
|
|
const int mmq_y = get_mmq_y_host(cc);
|