|
|
@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
|
|
1;
|
|
|
}
|
|
|
|
|
|
+enum mmvq_parameter_table_id {
|
|
|
+ MMVQ_PARAMETERS_GENERIC = 0,
|
|
|
+ MMVQ_PARAMETERS_GCN,
|
|
|
+ MMVQ_PARAMETERS_RDNA2
|
|
|
+};
|
|
|
+
|
|
|
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
|
|
|
+#if defined(RDNA2) || defined(RDNA3)
|
|
|
+ return MMVQ_PARAMETERS_RDNA2;
|
|
|
+#elif defined(GCN) || defined(CDNA)
|
|
|
+ return MMVQ_PARAMETERS_GCN;
|
|
|
+#else
|
|
|
+ return MMVQ_PARAMETERS_GENERIC;
|
|
|
+#endif
|
|
|
+}
|
|
|
+
|
|
|
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
|
|
+ if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
|
|
|
+ return MMVQ_PARAMETERS_RDNA2;
|
|
|
+ }
|
|
|
+ if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
|
|
|
+ return MMVQ_PARAMETERS_GCN;
|
|
|
+ }
|
|
|
+ return MMVQ_PARAMETERS_GENERIC;
|
|
|
+}
|
|
|
+
|
|
|
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
|
|
|
+ if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
|
|
+ switch (ncols_y) {
|
|
|
+ case 1:
|
|
|
+ case 2:
|
|
|
+ case 3:
|
|
|
+ case 4:
|
|
|
+ return 4;
|
|
|
+ case 5:
|
|
|
+ case 6:
|
|
|
+ case 7:
|
|
|
+ case 8:
|
|
|
+ return 2;
|
|
|
+ default:
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ } else if (table_id == MMVQ_PARAMETERS_GCN) {
|
|
|
+ switch (ncols_y) {
|
|
|
+ case 1:
|
|
|
+ case 2:
|
|
|
+ case 3:
|
|
|
+ case 4:
|
|
|
+ return 2;
|
|
|
+ case 5:
|
|
|
+ case 6:
|
|
|
+ case 7:
|
|
|
+ case 8:
|
|
|
+ default:
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return 1;
|
|
|
+}
|
|
|
+
|
|
|
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
|
|
|
+ if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
|
|
|
+ switch (ncols_y) {
|
|
|
+ case 1:
|
|
|
+ return 1;
|
|
|
+ case 2:
|
|
|
+ case 3:
|
|
|
+ case 4:
|
|
|
+ case 5:
|
|
|
+ case 6:
|
|
|
+ case 7:
|
|
|
+ case 8:
|
|
|
+ return 2;
|
|
|
+ default:
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return 1;
|
|
|
+}
|
|
|
+
|
|
|
template <ggml_type type, int ncols_y>
|
|
|
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
|
|
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
|
|
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
|
|
static __global__ void mul_mat_vec_q(
|
|
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
|
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
|
|
@@ -59,24 +137,20 @@ static __global__ void mul_mat_vec_q(
|
|
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
|
constexpr int vdr = get_vdr_mmvq(type);
|
|
|
+ constexpr mmvq_parameter_table_id table_id = get_device_table_id();
|
|
|
+ constexpr int nwarps = calc_nwarps(ncols_y, table_id);
|
|
|
+ constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
|
|
|
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
|
|
|
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
|
|
|
|
|
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
|
|
- constexpr int nwarps = 1;
|
|
|
- constexpr int rows_per_cuda_block = 1;
|
|
|
-#else
|
|
|
- constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
|
|
- constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
|
|
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
|
|
-
|
|
|
- const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
|
|
+ const int tid = warp_size*threadIdx.y + threadIdx.x;
|
|
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
|
|
const int blocks_per_row_x = ncols_x / qk;
|
|
|
const int blocks_per_col_y = nrows_y / QK8_1;
|
|
|
- constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
|
|
+ constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
|
|
|
|
|
|
-// partial sum for each thread
|
|
|
+ // partial sum for each thread
|
|
|
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
|
|
|
|
|
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
|
|
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
|
|
|
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
|
|
|
if (threadIdx.y > 0) {
|
|
|
#pragma unroll
|
|
|
for (int j = 0; j < ncols_y; ++j) {
|
|
|
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
|
|
|
for (int l = 0; l < nwarps-1; ++l) {
|
|
|
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
|
|
|
}
|
|
|
- tmp[j][i] = warp_reduce_sum(tmp[j][i]);
|
|
|
+ tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
|
|
}
|
|
|
|
|
|
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
|
|
|
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
|
|
|
+ const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
|
|
|
+ const dim3 block_nums(nblocks, 1, 1);
|
|
|
+ const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
|
|
|
+ return {block_nums, block_dims};
|
|
|
+}
|
|
|
+
|
|
|
template <ggml_type type>
|
|
|
static void mul_mat_vec_q_cuda(
|
|
|
const void * vx, const void * vy, float * dst,
|
|
|
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
|
|
|
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
|
|
|
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
|
|
|
|
|
- int id = ggml_cuda_get_device();
|
|
|
-
|
|
|
- int64_t nwarps = 1;
|
|
|
- int64_t rows_per_cuda_block = 1;
|
|
|
-
|
|
|
- if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
|
|
- switch(ncols_y) {
|
|
|
- case 1:
|
|
|
- nwarps = 4;
|
|
|
- rows_per_cuda_block = 1;
|
|
|
- break;
|
|
|
- case 2:
|
|
|
- case 3:
|
|
|
- case 4:
|
|
|
- nwarps = 4;
|
|
|
- rows_per_cuda_block = 2;
|
|
|
- break;
|
|
|
- case 5:
|
|
|
- case 6:
|
|
|
- case 7:
|
|
|
- case 8:
|
|
|
- nwarps = 2;
|
|
|
- rows_per_cuda_block = 2;
|
|
|
- break;
|
|
|
- default:
|
|
|
- GGML_ABORT("fatal error");
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
|
|
- const dim3 block_nums(nblocks, 1, 1);
|
|
|
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
|
|
+ const int device = ggml_cuda_get_device();
|
|
|
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
|
+ const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
|
|
|
|
|
switch (ncols_y) {
|
|
|
case 1:
|
|
|
- mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 1;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 2:
|
|
|
- mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 2;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 3:
|
|
|
- mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 3;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 4:
|
|
|
- mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 4;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 5:
|
|
|
- mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 5;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 6:
|
|
|
- mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 6;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 7:
|
|
|
- mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 7;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
case 8:
|
|
|
- mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
+ {
|
|
|
+ constexpr int c_ncols_y = 8;
|
|
|
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
|
|
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
|
break;
|
|
|
+ }
|
|
|
default:
|
|
|
GGML_ABORT("fatal error");
|
|
|
break;
|